Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

global stop condition #2248

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 82 additions & 20 deletions algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,7 @@ def __init__(
gene_conversion_rate=0.0,
gene_conversion_length=1,
discrete_genome=True,
stop_condition=None,
):
# Must be a square matrix.
N = len(migration_matrix)
Expand Down Expand Up @@ -640,6 +641,7 @@ def __init__(
for time in census_times:
self.modifier_events.append((time[0], self.census_event, time))
self.modifier_events.sort()
self.stop_condition = stop_condition

def initialise(self, ts):
root_time = np.max(self.tables.nodes.time)
Expand Down Expand Up @@ -684,12 +686,34 @@ def initialise(self, ts):
self.set_segment_mass(seg)
seg = seg.next

def get_num_ancestors(self):
    """
    Returns the total number of ancestors, summed over every population
    in ``self.P``.
    """
    total = 0
    for population in self.P:
        total += population.get_num_ancestors()
    return total

def ancestors_remain(self):
    """
    Returns True if the simulation is not finished, i.e., there is some ancestral
    material that has not fully coalesced.
    """
    # Delegate to get_num_ancestors() so the per-population sum lives in one
    # place; the superseded inline sum made this return unreachable.
    return self.get_num_ancestors() != 0

def assert_stop_condition(self):
    """
    Returns True if the simulation is not finished given the global
    stopping condition that was specified (``self.stop_condition``).

    - ``None``: continue while any ancestors remain.
    - ``"grand_mrca"``: continue while more than one ancestor remains.
    - ``"all_local_mrcas"``: continue while any value stored in ``self.S``
      is greater than 1.
    - ``"time"`` / ``"pedigree"``: always True, so this check never ends
      the simulation on its own for these conditions.

    Raises ValueError (carrying the diagnostic message) for any other value.
    """
    condition = self.stop_condition
    if condition is None:
        return self.ancestors_remain()
    if condition == "grand_mrca":
        return self.get_num_ancestors() > 1
    if condition == "all_local_mrcas":
        return any(num_anc > 1 for num_anc in self.S.values())
    if condition in ("time", "pedigree"):
        return True
    # Keep the printed diagnostic, but also attach it to the exception so
    # callers that catch the error can see which value was rejected.
    message = f"Error: unknown stop condition- {condition}"
    print(message)
    raise ValueError(message)

def change_population_size(self, pop_id, size):
self.P[pop_id].set_start_size(size)
Expand Down Expand Up @@ -835,16 +859,19 @@ def finalise(self):
def simulate(self, end_time):
    """
    Runs the simulation for the configured model and returns the result
    of ``self.finalise()``.

    Each model-specific routine is invoked exactly once and its return
    code is captured. A return code of 2 means the routine stopped because
    ``end_time`` was reached, in which case the clock ``self.t`` is moved
    to ``end_time`` before finalising.

    Raises ValueError if ``self.model`` is not a recognised model name.
    """
    self.verify()
    if self.model == "hudson":
        ret = self.hudson_simulate(end_time)
    elif self.model == "dtwf":
        ret = self.dtwf_simulate(end_time)
    elif self.model == "fixed_pedigree":
        ret = self.pedigree_simulate(end_time)
    elif self.model == "single_sweep":
        ret = self.single_sweep_simulate()
    else:
        # Keep the printed diagnostic and also attach it to the exception.
        message = f"Error: bad model specification - {self.model}"
        print(message)
        raise ValueError(message)

    if ret == 2:  # _msprime.EXIT_MAX_TIME
        self.t = end_time
    return self.finalise()

def get_potential_destinations(self):
Expand Down Expand Up @@ -897,15 +924,14 @@ def hudson_simulate(self, end_time):
"""
Simulates the algorithm until all loci have coalesced.
"""
ret = 0
infinity = sys.float_info.max
non_empty_pops = {pop.id for pop in self.P if pop.get_num_ancestors() > 0}
potential_destinations = self.get_potential_destinations()

# only worried about label 0 below
while len(non_empty_pops) > 0:
while self.assert_stop_condition():
self.verify()
if self.t >= end_time:
break
# self.print_state()
re_rate = self.get_total_recombination_rate(label=0)
t_re = infinity
Expand Down Expand Up @@ -948,6 +974,9 @@ def hudson_simulate(self, end_time):
mig_dest = k
min_time = min(t_re, t_ca, t_gcin, t_gc_left, t_mig)
assert min_time != infinity
if self.t + min_time > end_time:
ret = 2 # _msprime.MAX_EVENT_TIME
break
if self.t + min_time > self.modifier_events[0][0]:
t, func, args = self.modifier_events.pop(0)
self.t = t
Expand Down Expand Up @@ -992,6 +1021,8 @@ def hudson_simulate(self, end_time):
X = {pop.id for pop in self.P if pop.get_num_ancestors() > 0}
assert non_empty_pops == X

return ret

def single_sweep_simulate(self):
"""
Does a structured coalescent until end_freq is reached, using
Expand Down Expand Up @@ -1099,21 +1130,27 @@ def single_sweep_simulate(self):
self.set_labels(u, 0)
self.P[0].add(tmp)

def pedigree_simulate(self, end_time):
    """
    Simulates through the provided pedigree, built from ``self.tables``.

    The pedigree is climbed exactly once, with ``end_time`` forwarded to
    ``dtwf_climb_pedigree``. Returns that method's return code unchanged.
    """
    self.pedigree = Pedigree(self.tables)
    return self.dtwf_climb_pedigree(end_time)

def dtwf_simulate(self, end_time):
    """
    Simulates the algorithm one generation at a time until the global stop
    condition is met or the next generation would pass ``end_time``.

    Returns 0 when the stop condition ended the loop, or 2 when the loop
    was cut short because of ``end_time``.
    """
    ret = 0
    while self.assert_stop_condition():
        # Check before advancing the clock so self.t never exceeds end_time.
        if self.t + 1 > end_time:
            ret = 2  # _msprime.EXIT_MAX_TIME
            break
        self.t += 1
        self.verify()
        self.dtwf_generation()
    return ret

def dtwf_generation(self):
"""
Expand Down Expand Up @@ -1251,13 +1288,14 @@ def process_pedigree_common_ancestors(self, ind, ploid):
self.flush_edges()
self.verify()

def dtwf_climb_pedigree(self):
def dtwf_climb_pedigree(self, end_time):
"""
Simulates transmission of ancestral material through a pre-specified
pedigree
"""
assert self.num_populations == 1 # Single pop/pedigree for now
pop = self.P[0]
ret = 0

# Go through the extant lineages and gather the ancestral material
# into the corresponding pedigree individuals.
Expand All @@ -1270,9 +1308,13 @@ def dtwf_climb_pedigree(self):
# Visit pedigree individuals in time order.
visit_order = sorted(self.pedigree.individuals, key=lambda x: (x.time, x.id))
for ind in visit_order:
if ind.time > end_time:
ret = 2 # _msprime.EXIT_MAX_TIME
break
self.t = ind.time
for ploid in range(ind.ploidy):
self.process_pedigree_common_ancestors(ind, ploid)
return ret

def store_arg_edges(self, segment, u=-1):
if u == -1:
Expand Down Expand Up @@ -1810,13 +1852,19 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1):
j = self.S.floor_key(r_max)
self.S[r_max] = self.S[j]
# Update the number of extant segments.
if self.S[left] == len(X):
if self.S[left] == len(X) and self.stop_condition is None:
self.S[left] = 0
right = self.S.succ_key(left)
else:
right = left
while right < r_max and self.S[right] != len(X):
self.S[right] -= len(X) - 1
while right < r_max:
if self.S[right] <= len(X):
if self.stop_condition is None:
break
else:
self.S[right] = 1
else:
self.S[right] -= len(X) - 1
right = self.S.succ_key(right)
alpha = self.alloc_segment(left, right, new_node_id, pop_id)
# Update the heaps and make the record.
Expand Down Expand Up @@ -1956,13 +2004,19 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1):
j = self.S.floor_key(r_max)
self.S[r_max] = self.S[j]
# Update the number of extant segments.
if self.S[left] == 2:
if self.S[left] == 2 and self.stop_condition is None:
self.S[left] = 0
right = self.S.succ_key(left)
else:
right = left
while right < r_max and self.S[right] != 2:
self.S[right] -= 1
while right < r_max:
if self.S[right] <= 2:
if self.stop_condition is None:
break
else:
self.S[right] = 1
else:
self.S[right] -= 1
right = self.S.succ_key(right)
alpha = self.alloc_segment(
left=left,
Expand Down Expand Up @@ -2253,6 +2307,10 @@ def run_simulate(args):
else:
from_ts = tskit.load(args.from_ts)
tables = from_ts.dump_tables()
if args.stop_condition == "pedigree":
end_time = np.max(from_ts.nodes_time)
else:
end_time = args.end_time

s = Simulator(
tables=tables,
Expand All @@ -2275,8 +2333,9 @@ def run_simulate(args):
gene_conversion_rate=gc_rate,
gene_conversion_length=mean_tract_length,
discrete_genome=args.discrete,
stop_condition=args.stop_condition,
)
ts = s.simulate(args.end_time)
ts = s.simulate(end_time)
ts.dump(args.output_file)
if args.verbose:
s.print_state()
Expand Down Expand Up @@ -2373,6 +2432,9 @@ def add_simulator_arguments(parser):
parser.add_argument(
"--end-time", type=float, default=np.inf, help="The end for simulations."
)
parser.add_argument(
"--stop-condition", type=str, default=None, help="Global stopping condition"
)


def main(args=None):
Expand Down
50 changes: 50 additions & 0 deletions tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,3 +431,53 @@ def test_one_gen_pedigree(self, num_founders):
tables.dump(ts_path)
ts = self.run_script(f"0 --from-ts {ts_path} -r 1 --model=fixed_pedigree")
assert len(ts.dump_tables().edges) == 0

def test_stopping_condition_time(self):
    # With the "time" stop condition the run is expected to end exactly at
    # the requested end time.
    cutoff = 2.5
    command = f"10 --stop-condition=time --end-time={cutoff}"
    ts = self.run_script(command)
    assert ts.max_root_time == cutoff
    assert ts.num_samples == 10
    assert ts.num_trees > 1
    assert not has_discrete_genome(ts)
    assert ts.sequence_length == 100

def test_stopping_condition_all_mrcas(self):
    ts = self.run_script("10 --stop-condition=all_local_mrcas")
    # The trees must not all share one root node ...
    distinct_roots = {tree.root for tree in ts.trees()}
    assert len(distinct_roots) > 1
    # ... but every tree's root must sit at the same time.
    distinct_root_times = {tree.time(tree.root) for tree in ts.trees()}
    assert len(distinct_root_times) == 1
    assert ts.num_samples == 10
    assert ts.num_trees > 1
    assert not has_discrete_genome(ts)
    assert ts.sequence_length == 100

def test_stopping_condition_grand_mrca(self):
    # With "grand_mrca" every tree is expected to share a single root node.
    ts = self.run_script("10 --stop-condition=grand_mrca")
    assert ts.num_trees > 1
    distinct_roots = {tree.root for tree in ts.trees()}
    assert len(distinct_roots) == 1

def test_stopping_condition_pedigree(self):
    # Simulate through a generated pedigree with the "pedigree" stop
    # condition and check that the roots sit at the pedigree's top generation.
    num_founders = 4
    num_generations = 10
    tables = simulate_pedigree(
        num_founders=num_founders, num_generations=num_generations
    )
    with tempfile.TemporaryDirectory() as tmpdir:
        ts_path = pathlib.Path(tmpdir) / "pedigree.trees"
        tables.dump(ts_path)
        # Adjacent string literals instead of a backslash continuation inside
        # the f-string, so the command text never picks up source indentation.
        ts = self.run_script(
            f"0 --from-ts {ts_path} --model=fixed_pedigree -r 0.1 "
            "--stop-condition=pedigree"
        )
        assert ts.num_trees > 1
        assert ts.max_root_time == num_generations - 1

def test_stopping_condition_dtwf(self):
    # The "time" stop condition under the dtwf model should end the run at
    # the requested end time.
    cutoff = 20
    command = f"10 --model=dtwf --stop-condition=time --end-time={cutoff}"
    ts = self.run_script(command)
    assert ts.num_trees > 1
    assert ts.max_root_time == cutoff
Loading