diff --git a/algorithms.py b/algorithms.py index 55fa82c66..b2d90bdb3 100644 --- a/algorithms.py +++ b/algorithms.py @@ -549,6 +549,7 @@ def __init__( gene_conversion_rate=0.0, gene_conversion_length=1, discrete_genome=True, + stop_condition=None, ): # Must be a square matrix. N = len(migration_matrix) @@ -640,6 +641,7 @@ def __init__( for time in census_times: self.modifier_events.append((time[0], self.census_event, time)) self.modifier_events.sort() + self.stop_condition = stop_condition def initialise(self, ts): root_time = np.max(self.tables.nodes.time) @@ -684,12 +686,34 @@ def initialise(self, ts): self.set_segment_mass(seg) seg = seg.next + def get_num_ancestors(self): + return sum(pop.get_num_ancestors() for pop in self.P) + def ancestors_remain(self): """ Returns True if the simulation is not finished, i.e., there is some ancestral material that has not fully coalesced. """ - return sum(pop.get_num_ancestors() for pop in self.P) != 0 + return self.get_num_ancestors() != 0 + + def assert_stop_condition(self): + """ + Returns true if the simulation is not finished given the global + stopping condition that was specified. + """ + if self.stop_condition is None: + return self.ancestors_remain() + elif self.stop_condition == "grand_mrca": + return self.get_num_ancestors() > 1 + elif self.stop_condition == "all_local_mrcas": + return any(num_anc > 1 for num_anc in self.S.values()) + elif self.stop_condition == "time": + return True + elif self.stop_condition == "pedigree": + return True + else: + print("Error: unknown stop condition-", self.stop_condition) + raise ValueError def change_population_size(self, pop_id, size): self.P[pop_id].set_start_size(size) @@ -835,16 +859,19 @@ def finalise(self): def simulate(self, end_time): self.verify() if self.model == "hudson": - self.hudson_simulate(end_time) + ret = self.hudson_simulate(end_time) elif self.model == "dtwf": - self.dtwf_simulate() + ret = self.dtwf_simulate(end_time) elif self.model == "fixed_pedigree": - self.pedigree_simulate() + ret = self.pedigree_simulate(end_time) elif self.model == "single_sweep": - self.single_sweep_simulate() + ret = self.single_sweep_simulate() else: print("Error: bad model specification -", self.model) raise ValueError + + if ret == 2: # _msprime.EXIT_MAX_TIME: + self.t = end_time return self.finalise() def get_potential_destinations(self): @@ -897,15 +924,14 @@ def hudson_simulate(self, end_time): """ Simulates the algorithm until all loci have coalesced. """ + ret = 0 infinity = sys.float_info.max non_empty_pops = {pop.id for pop in self.P if pop.get_num_ancestors() > 0} potential_destinations = self.get_potential_destinations() # only worried about label 0 below - while len(non_empty_pops) > 0: + while self.assert_stop_condition(): self.verify() - if self.t >= end_time: - break # self.print_state() re_rate = self.get_total_recombination_rate(label=0) t_re = infinity @@ -948,6 +974,9 @@ def hudson_simulate(self, end_time): mig_dest = k min_time = min(t_re, t_ca, t_gcin, t_gc_left, t_mig) assert min_time != infinity + if self.t + min_time > end_time: + ret = 2 # _msprime.MAX_EVENT_TIME + break if self.t + min_time > self.modifier_events[0][0]: t, func, args = self.modifier_events.pop(0) self.t = t @@ -992,6 +1021,8 @@ def hudson_simulate(self, end_time): X = {pop.id for pop in self.P if pop.get_num_ancestors() > 0} assert non_empty_pops == X + return ret + def single_sweep_simulate(self): """ Does a structed coalescent until end_freq is reached, using @@ -1099,21 +1130,27 @@ def single_sweep_simulate(self): self.set_labels(u, 0) self.P[0].add(tmp) - def pedigree_simulate(self): + def pedigree_simulate(self, end_time): """ Simulates through the provided pedigree, stopping at the top. """ self.pedigree = Pedigree(self.tables) - self.dtwf_climb_pedigree() + ret = self.dtwf_climb_pedigree(end_time) + return ret - def dtwf_simulate(self): + def dtwf_simulate(self, end_time): """ Simulates the algorithm until all loci have coalesced. """ - while self.ancestors_remain(): + ret = 0 + while self.assert_stop_condition(): + if self.t + 1 > end_time: + ret = 2 # _msprime.EXIT_MAX_TIME + break self.t += 1 self.verify() self.dtwf_generation() + return ret def dtwf_generation(self): """ @@ -1251,13 +1288,14 @@ def process_pedigree_common_ancestors(self, ind, ploid): self.flush_edges() self.verify() - def dtwf_climb_pedigree(self): + def dtwf_climb_pedigree(self, end_time): """ Simulates transmission of ancestral material through a pre-specified pedigree """ assert self.num_populations == 1 # Single pop/pedigree for now pop = self.P[0] + ret = 0 # Go through the extant lineages and gather the ancestral material # into the corresponding pedigree individuals. @@ -1270,9 +1308,13 @@ def dtwf_climb_pedigree(self): # Visit pedigree individuals in time order. visit_order = sorted(self.pedigree.individuals, key=lambda x: (x.time, x.id)) for ind in visit_order: + if ind.time > end_time: + ret = 2 # _msprime.EXIT_MAX_TIME + break self.t = ind.time for ploid in range(ind.ploidy): self.process_pedigree_common_ancestors(ind, ploid) + return ret def store_arg_edges(self, segment, u=-1): if u == -1: @@ -1810,13 +1852,19 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1): j = self.S.floor_key(r_max) self.S[r_max] = self.S[j] # Update the number of extant segments. - if self.S[left] == len(X): + if self.S[left] == len(X) and self.stop_condition is None: self.S[left] = 0 right = self.S.succ_key(left) else: right = left - while right < r_max and self.S[right] != len(X): - self.S[right] -= len(X) - 1 + while right < r_max: + if self.S[right] <= len(X): + if self.stop_condition is None: + break + else: + self.S[right] = 1 + else: + self.S[right] -= len(X) - 1 right = self.S.succ_key(right) alpha = self.alloc_segment(left, right, new_node_id, pop_id) # Update the heaps and make the record. @@ -1956,13 +2004,19 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1): j = self.S.floor_key(r_max) self.S[r_max] = self.S[j] # Update the number of extant segments. - if self.S[left] == 2: + if self.S[left] == 2 and self.stop_condition is None: self.S[left] = 0 right = self.S.succ_key(left) else: right = left - while right < r_max and self.S[right] != 2: - self.S[right] -= 1 + while right < r_max: + if self.S[right] <= 2: + if self.stop_condition is None: + break + else: + self.S[right] = 1 + else: + self.S[right] -= 1 right = self.S.succ_key(right) alpha = self.alloc_segment( left=left, @@ -2253,6 +2307,10 @@ def run_simulate(args): else: from_ts = tskit.load(args.from_ts) tables = from_ts.dump_tables() + if args.stop_condition == "pedigree": + end_time = np.max(from_ts.nodes_time) + else: + end_time = args.end_time s = Simulator( tables=tables, @@ -2275,8 +2333,9 @@ def run_simulate(args): gene_conversion_rate=gc_rate, gene_conversion_length=mean_tract_length, discrete_genome=args.discrete, + stop_condition=args.stop_condition, ) - ts = s.simulate(args.end_time) + ts = s.simulate(end_time) ts.dump(args.output_file) if args.verbose: s.print_state() @@ -2373,6 +2432,9 @@ def add_simulator_arguments(parser): parser.add_argument( "--end-time", type=float, default=np.inf, help="The end for simulations." ) + parser.add_argument( + "--stop-condition", type=str, default=None, help="Global stopping condition" + ) def main(args=None): diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index ce48c4778..d62e1d16f 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -431,3 +431,53 @@ def test_one_gen_pedigree(self, num_founders): tables.dump(ts_path) ts = self.run_script(f"0 --from-ts {ts_path} -r 1 --model=fixed_pedigree") assert len(ts.dump_tables().edges) == 0 + + def test_stopping_condition_time(self): + end_time = 2.5 + ts = self.run_script(f"10 --stop-condition=time --end-time={end_time}") + assert ts.max_root_time == end_time + assert ts.num_samples == 10 + assert ts.num_trees > 1 + assert not has_discrete_genome(ts) + assert ts.sequence_length == 100 + + def test_stopping_condition_all_mrcas(self): + ts = self.run_script("10 --stop-condition=all_local_mrcas") + roots = [tree.root for tree in ts.trees()] + assert len(set(roots)) > 1 + roots_time = [tree.time(tree.root) for tree in ts.trees()] + assert len(set(roots_time)) == 1 + assert ts.num_samples == 10 + assert ts.num_trees > 1 + assert not has_discrete_genome(ts) + assert ts.sequence_length == 100 + + def test_stopping_condition_grand_mrca(self): + ts = self.run_script("10 --stop-condition=grand_mrca") + assert ts.num_trees > 1 + roots = [tree.root for tree in ts.trees()] + assert len(set(roots)) == 1 + + def test_stopping_condition_pedigree(self): + num_founders = 4 + num_generations = 10 + tables = simulate_pedigree( + num_founders=num_founders, num_generations=num_generations + ) + with tempfile.TemporaryDirectory() as tmpdir: + ts_path = pathlib.Path(tmpdir) / "pedigree.trees" + tables.dump(ts_path) + ts = self.run_script( + f"0 --from-ts {ts_path} --model=fixed_pedigree -r 0.1 \ + --stop-condition=pedigree" + ) + assert ts.num_trees > 1 + assert ts.max_root_time == num_generations - 1 + + def test_stopping_condition_dtwf(self): + end_time = 20 + ts = self.run_script( + f"10 --model=dtwf --stop-condition=time --end-time={end_time}" + ) + assert ts.num_trees > 1 + assert ts.max_root_time == end_time