Skip to content

Commit 9529662

Browse files
committed
ageBasedContactsSampling contact matrix can be of any size
1 parent d2e0b82 commit 9529662

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

src/methods/contact_sampling_methods.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,15 @@ function sample_contacts(contactparameter_sampling::AgeBasedContactSampling, set
105105
return Individual[]
106106
end
107107
interval = contactparameter_sampling.contact_matrix.interval_steps
108-
max_age = contactparameter_sampling.contact_matrix.aggregation_bound
109-
orig_bin = (individual.age ÷ interval) + 1
108+
max_age = contactparameter_sampling.contact_matrix.aggregation_bound - 1
109+
orig_bin = (min(individual.age, max_age) ÷ interval) + 1
110110
contact_matrix::Matrix{Float64} = contactparameter_sampling.contact_matrix.data
111111
age_pyramid = contactparameter_sampling.age_pyramid
112112
# if age_pyramid is not ready compute it
113113
if size(age_pyramid)[1] == 0
114114
age_pyramid = zeros(size(contact_matrix)[1])
115115
for ind in present_inds
116-
interval_id = ind.age ÷ interval + 1
116+
interval_id = min(ind.age, max_age) ÷ interval + 1
117117
age_pyramid[interval_id] += 1
118118
end
119119
age_pyramid = age_pyramid ./ sum(age_pyramid)
@@ -155,7 +155,8 @@ function sample_contacts(contactparameter_sampling::AgeBasedContactSampling, set
155155
# Second order sampling (i.e. structural one)
156156
out = Individual[]
157157
for i = 1:number_of_contacts
158-
dest_bin = (res[i].age ÷ interval) + 1
158+
159+
dest_bin = (min(res[i].age, max_age) ÷ interval) + 1
159160
m = contact_matrix[orig_bin, dest_bin]
160161
if m > 0.0
161162
m = m / m_max # since we multiplied by m_max in line no. 113

src/structs/parameters/contact_sampling_method_structs.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ mutable struct AgeBasedContactSampling <: ContactSamplingMethod
8686
throw(ArgumentError("Sum of row $i in 'contact_matrix' is $s, but the sum has to be equal to 1.0!"))
8787
end
8888
end
89-
contact_matrix = ContactMatrix{Float64}(matrix, interval)
89+
aggregation_bound = size(matrix)[1] * interval
90+
contact_matrix = ContactMatrix{Float64}(matrix, interval, aggregation_bound)
9091
return new(contactparameter, interval, contact_matrix, Float64[])
9192
end
9293

0 commit comments

Comments
 (0)