From 3777baba12ddbaf121032585e8eb6dce3b720702 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 23 Jul 2024 15:40:29 +0100 Subject: [PATCH 1/3] Refactor algorithms.py code for basic lineage struct Cover all the easy cases Get SMCK working Add a "lineage" attr to Segment and fixup sweeps code Basic DTWF seems to work Fixup algorithms minimally --- algorithms.py | 244 +++++++++++++++++++++++---------------- tests/test_algorithms.py | 40 +++---- 2 files changed, 156 insertions(+), 128 deletions(-) diff --git a/algorithms.py b/algorithms.py index 5925b16b2..dec9976e7 100644 --- a/algorithms.py +++ b/algorithms.py @@ -2,6 +2,7 @@ Python version of the simulation algorithm. """ import argparse +import dataclasses import heapq import itertools import logging @@ -127,6 +128,7 @@ def __init__(self, index): self.label = 0 self.index = index self.hull = None + self.lineage = None def __repr__(self): return repr((self.left, self.right, self.node)) @@ -164,6 +166,11 @@ def get_left_index(self): return index +@dataclasses.dataclass +class Lineage: + head: Segment + + class Population: """ Class representing a population in the simulation. @@ -204,8 +211,8 @@ def print_state(self): print("\tAncestors: ", len(self._ancestors)) for label, ancestors in enumerate(self._ancestors): print("\tLabel = ", label) - for u in ancestors: - print("\t\t" + Segment.show_chain(u)) + for lineage in ancestors: + print("\t\t" + Segment.show_chain(lineage.head)) def set_growth_rate(self, growth_rate, time): # TODO This doesn't work because we need to know what the time @@ -366,6 +373,7 @@ def remove_individual(self, individual, label=0): """ Removes the given individual from its population. """ + assert isinstance(individual, Lineage) return self._ancestors[label].remove(individual) def add_hull(self, label, hull): @@ -408,7 +416,8 @@ def add(self, individual, label=0): """ Inserts the specified individual into this population. """ - assert individual.label == label + assert isinstance(individual, Lineage) + assert individual.head.label == label self._ancestors[label].append(individual) def __iter__(self): @@ -429,12 +438,6 @@ def iter_ancestors(self): for ancestors in self._ancestors: yield from ancestors - def find_indv(self, indv): - """ - find the index of an ancestor in population - """ - return self._ancestors[indv.label].index(indv) - class Pedigree: """ @@ -694,7 +697,7 @@ class Hull: def __init__(self, index): self.left = None self.right = None - self.lineage_head = None + self.lineage = None self.index = index self.insertion_order = math.inf @@ -984,7 +987,8 @@ def initialise(self, ts): left_end = seg.left pop = seg.population label = seg.label - self.P[seg.population].add(seg) + lineage = self.alloc_lineage(seg) + self.P[seg.population].add(lineage) while seg is not None: self.set_segment_mass(seg) seg = seg.next @@ -996,9 +1000,9 @@ def initialise(self, ts): left_end = seg.left pop = seg.population label = seg.label - lineage_head = seg + lineage = seg.lineage right_end = root_segments_tail[node].right - new_hull = self.alloc_hull(left_end, right_end, lineage_head) + new_hull = self.alloc_hull(left_end, right_end, lineage) # insert Hull floor = self.P[pop].hulls_left[label].floor_key(new_hull) insertion_order = 0 @@ -1049,15 +1053,15 @@ def change_population_growth_rate(self, pop_id, rate, time): def change_migration_matrix_element(self, pop_i, pop_j, rate): self.migration_matrix[pop_i][pop_j] = rate - def alloc_hull(self, left, right, lineage_head): - alpha = lineage_head + def alloc_hull(self, left, right, lineage): + alpha = lineage.head hull = self.hull_stack.pop() hull.left = left hull.right = right while alpha.prev is not None: alpha = alpha.prev assert alpha is not None - hull.lineage_head = alpha + hull.lineage = lineage alpha.hull = hull return hull @@ -1086,6 +1090,11 @@ def alloc_segment( s.hull = hull return s + def alloc_lineage(self, head): + lineage = Lineage(head) + head.lineage = lineage + return lineage + def copy_segment(self, segment): return self.alloc_segment( left=segment.left, @@ -1171,11 +1180,11 @@ def finalise(self): # Insert unary edges for any remainining lineages. current_time = self.t for population in self.P: - for ancestor in population.iter_ancestors(): + for lineage in population.iter_ancestors(): node = tskit.NULL # See if there is already a node in this ancestor at the # current time - seg = ancestor + seg = lineage.head while seg is not None: if self.tables.nodes[seg.node].time == current_time: node = seg.node @@ -1187,7 +1196,7 @@ def finalise(self): flags=0, time=current_time, population=population.id ) # Add in edges pointing to this ancestor - seg = ancestor + seg = lineage.head while seg is not None: if seg.node != node: self.tables.edges.add_row(seg.left, seg.right, node, seg.node) @@ -1383,12 +1392,12 @@ def single_sweep_simulate(self): # a bit ugly with the two loops because # of dealing with the pops indices = [] - for idx, u in enumerate(self.P[0].iter_label(0)): + for idx, lineage in enumerate(self.P[0].iter_label(0)): if random.random() < x: - self.set_labels(u, 1) + self.set_labels(lineage, 1) indices.append(idx) else: - assert u.label == 0 + assert lineage.head.label == 0 popped = 0 for i in indices: tmp = self.P[0].remove(i - popped, 0) @@ -1469,9 +1478,9 @@ def single_sweep_simulate(self): 0, self.sweep_site, 1.0 - x ) # clean up the labels at end - for idx, u in enumerate(self.P[0].iter_label(1)): - tmp = self.P[0].remove(idx, u.label) - self.set_labels(u, 0) + for idx, lineage in enumerate(self.P[0].iter_label(1)): + tmp = self.P[0].remove(idx, label=1) + self.set_labels(lineage, 0) self.P[0].add(tmp) def pedigree_simulate(self): @@ -1524,18 +1533,17 @@ def dtwf_generation(self): parent_nodes = [-1, -1] H = [[], []] for child in children: - segs_pair = self.dtwf_recombine(child, parent_nodes) - for seg in segs_pair: - if seg is not None and seg.index != child.index: - pop.add(seg) + lin_pair = self.dtwf_recombine(child, parent_nodes) + for lin in lin_pair: + if lin is not None and lin != child: + pop.add(lin) self.verify() # Collect segments inherited from the same individual - for i, seg in enumerate(segs_pair): - if seg is None: - continue - assert seg.prev is None - heapq.heappush(H[i], (seg.left, seg)) + for i, lin in enumerate(lin_pair): + if lin is not None: + assert lin.head.prev is None + heapq.heappush(H[i], (lin.head.left, lin.head)) # Merge segments for ploid, h in enumerate(H): @@ -1552,8 +1560,8 @@ def dtwf_generation(self): ) h = [] elif segments_to_merge >= 2: - for _, individual in h: - pop.remove_individual(individual) + for _, seg in h: + pop.remove_individual(seg.lineage) # parent_nodes[ploid] does not need to be updated here if segments_to_merge == 2: self.merge_two_ancestors( @@ -1580,9 +1588,9 @@ def process_pedigree_common_ancestors(self, ind, ploid): # All the segment chains in common_ancestors reach a common # ancestor in this ploid of this individual. First we remove # them from the populations they are stored in: - for _, anc in common_ancestors: - pop = self.P[anc.population] - pop.remove_individual(anc) + for _, seg in common_ancestors: + pop = self.P[seg.population] + pop.remove_individual(seg.lineage) # Merge together these lists of ancestral segments to create the # monoploid genome for this ploid of this individual. @@ -1600,7 +1608,7 @@ def process_pedigree_common_ancestors(self, ind, ploid): # simulation because we are *not* simulating the entire # population process, only the subset that we have information # about within the pedigree. - seg = genome + seg = genome.head while seg is not None: if seg.node != node: self.store_edge(seg.left, seg.right, parent=node, child=seg.node) @@ -1613,15 +1621,16 @@ def process_pedigree_common_ancestors(self, ind, ploid): # to create two independent lines of ancestry. parent = self.pedigree.individuals[ind.parents[ploid]] parent_ancestry = self.dtwf_recombine(genome, parent.nodes) + assert len(parent_ancestry) == ind.ploidy for parent_ploid in range(ind.ploidy): - seg = parent_ancestry[parent_ploid] - if seg is not None: + parent_lin = parent_ancestry[parent_ploid] + if parent_lin is not None: # Add this segment chain of ancestry to the accumulating # set in the parent on the corresponding ploid. - parent.add_common_ancestor(seg, ploid=parent_ploid) - if seg != genome: + parent.add_common_ancestor(parent_lin.head, ploid=parent_ploid) + if parent_lin != genome: # Add the recombined ancestor to the population - pop.add(seg) + pop.add(parent_lin) self.flush_edges() self.verify() @@ -1636,11 +1645,12 @@ def dtwf_climb_pedigree(self): # Go through the extant lineages and gather the ancestral material # into the corresponding pedigree individuals. - for anc in pop.iter_ancestors(): - node = self.tables.nodes[anc.node] + for lineage in pop.iter_ancestors(): + u = lineage.head.node + node = self.tables.nodes[u] assert node.individual != tskit.NULL ind = self.pedigree.individuals[node.individual] - ind.add_common_ancestor(anc, ploid=ind.nodes.index(anc.node)) + ind.add_common_ancestor(lineage.head, ploid=ind.nodes.index(u)) # Visit pedigree individuals in time order. visit_order = sorted(self.pedigree.individuals, key=lambda x: (x.time, x.id)) @@ -1676,10 +1686,11 @@ def migration_event(self, j, k): source = self.P[j] dest = self.P[k] index = random.randint(0, source.get_num_ancestors(label) - 1) - x = source.remove(index, label) + lineage = source.remove(index, label) + x = lineage.head hull = x.get_hull() assert (self.model == "smc_k") == (hull is not None) - dest.add(x, label) + dest.add(lineage, label) if self.model == "smc_k": source.remove_hull(label, hull) dest.add_hull(label, hull) @@ -1723,11 +1734,12 @@ def set_segment_mass(self, seg): gc_mass = self.gc_map.mass_between(gc_left_bound, seg.right) mass_index.set_value(seg.index, gc_mass) - def set_labels(self, segment, new_label): + def set_labels(self, lineage, new_label): """ - Move the specified segment to the specified label. + Move the specified lineage to the specified label. """ mass_indexes = [self.recomb_mass_index, self.gc_mass_index] + segment = lineage.head while segment is not None: masses = [] for mass_index in mass_indexes: @@ -1789,6 +1801,7 @@ def hudson_recombination_event(self, label, return_heads=False): alpha = y lhs_tail = x + right_lineage = self.alloc_lineage(alpha) if self.model == "smc_k": # modify original hull pop = alpha.population @@ -1798,11 +1811,11 @@ def hudson_recombination_event(self, label, return_heads=False): self.P[pop].reset_hull_right(label, lhs_hull, rhs_right, lhs_hull.right) # create hull for alpha - alpha_hull = self.alloc_hull(alpha.left, rhs_right, alpha) + alpha_hull = self.alloc_hull(alpha.left, rhs_right, right_lineage) self.P[alpha.population].add_hull(label, alpha_hull) self.set_segment_mass(alpha) - self.P[alpha.population].add(alpha, label) + self.P[alpha.population].add(right_lineage, label) if self.additional_nodes.value & msprime.NODE_IS_RE_EVENT > 0: self.store_node(lhs_tail.population, flags=msprime.NODE_IS_RE_EVENT) self.store_arg_edges(lhs_tail) @@ -1814,7 +1827,8 @@ def hudson_recombination_event(self, label, return_heads=False): # Seek back to the head of the x chain while x.prev is not None: x = x.prev - ret = x, alpha + left_lineage = x.lineage + ret = left_lineage, right_lineage return ret def generate_gc_tract_length(self): @@ -1959,15 +1973,16 @@ def wiuf_gene_conversion_within_event(self, label): elif head is not None: new_individual_head = head if new_individual_head is not None: + lineage = self.alloc_lineage(new_individual_head) if self.model == "smc_k": assert hull_left < hull_right hull_right = min(self.L, hull_right + self.hull_offset) - hull = self.alloc_hull(hull_left, hull_right, new_individual_head) + hull = self.alloc_hull(hull_left, hull_right, lineage) self.P[new_individual_head.population].add_hull( new_individual_head.label, hull ) self.P[new_individual_head.population].add( - new_individual_head, new_individual_head.label + lineage, new_individual_head.label ) def wiuf_gene_conversion_left_event(self, label): @@ -1976,7 +1991,8 @@ def wiuf_gene_conversion_left_event(self, label): """ random_gc_left = random.uniform(0, self.get_total_gc_left(label)) # Get segment where gene conversion starts from left - y = self.find_cleft_individual(label, random_gc_left) + lineage = self.find_cleft_individual(label, random_gc_left) + y = lineage.head assert y is not None # generate tract_length @@ -2044,29 +2060,30 @@ def wiuf_gene_conversion_left_event(self, label): self.set_segment_mass(alpha) assert alpha.prev is None - self.P[alpha.population].add(alpha, label) + lineage = self.alloc_lineage(alpha) + self.P[alpha.population].add(lineage, label) def hudson_recombination_event_sweep_phase(self, label, sweep_site, pop_freq): """ Implements a recombination event in during a selective sweep. """ - lhs, rhs = self.hudson_recombination_event(label, return_heads=True) + left_lin, right_lin = self.hudson_recombination_event(label, return_heads=True) + lhs = left_lin.head + rhs = right_lin.head r = random.random() if sweep_site < rhs.left: if r < 1.0 - pop_freq: # move rhs to other population - t_idx = self.P[rhs.population].find_indv(rhs) - self.P[rhs.population].remove(t_idx, rhs.label) - self.set_labels(rhs, 1 - label) - self.P[rhs.population].add(rhs, rhs.label) + self.P[rhs.population].remove_individual(right_lin, rhs.label) + self.set_labels(right_lin, 1 - label) + self.P[rhs.population].add(right_lin, rhs.label) else: if r < 1.0 - pop_freq: # move lhs to other population - t_idx = self.P[lhs.population].find_indv(lhs) - self.P[lhs.population].remove(t_idx, lhs.label) - self.set_labels(lhs, 1 - label) - self.P[lhs.population].add(lhs, lhs.label) + self.P[rhs.population].remove_individual(left_lin, lhs.label) + self.set_labels(left_lin, 1 - label) + self.P[lhs.population].add(left_lin, lhs.label) def dtwf_generate_breakpoint(self, start): left_bound = start + 1 if self.discrete_genome else start @@ -2076,7 +2093,7 @@ def dtwf_generate_breakpoint(self, start): bp = math.floor(bp) return bp - def dtwf_recombine(self, x, ind_nodes): + def dtwf_recombine(self, lineage, ind_nodes): """ Chooses breakpoints and returns segments sorted by inheritance direction, by iterating through segment chain starting with x @@ -2084,6 +2101,7 @@ def dtwf_recombine(self, x, ind_nodes): u = self.alloc_segment(-1, -1, -1, -1, None, None) v = self.alloc_segment(-1, -1, -1, -1, None, None) seg_tails = [u, v] + x = lineage.head # TODO Should this be the recombination rate going foward from x.left? if self.recomb_map.total_mass > 0: @@ -2162,12 +2180,22 @@ def dtwf_recombine(self, x, ind_nodes): segment, ) - return u, v + ret = [] + for seg in [u, v]: + if seg is None: + ret.append(None) + else: + if seg.lineage is lineage: + ret.append(lineage) + else: + ret.append(self.alloc_lineage(seg)) + + return ret def census_event(self, time): for pop in self.P: - for ancestor in pop.iter_ancestors(): - seg = ancestor + for lineage in pop.iter_ancestors(): + seg = lineage.head u = self.tables.nodes.add_row( time=time, flags=msprime.NODE_IS_CEN_EVENT, population=pop.id ) @@ -2184,7 +2212,8 @@ def bottleneck_event(self, pop_id, label, intensity): H = [] for _ in range(pop.get_num_ancestors()): if random.random() < intensity: - x = pop.remove(0) + lineage = pop.remove(0) + x = lineage.head heapq.heappush(H, (x.left, x)) self.merge_ancestors(H, pop_id, label) @@ -2204,7 +2233,7 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1): pass_through = len(H) == 1 alpha = None z = None - merged_head = None + new_lineage = None while len(H) > 0: alpha = None left = H[0][0] @@ -2267,8 +2296,8 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1): # loop tail; update alpha and integrate it into the state. if alpha is not None: if z is None: - pop.add(alpha, label) - merged_head = alpha + new_lineage = self.alloc_lineage(alpha) + pop.add(new_lineage, label) else: if (coalescence and not self.coalescing_segments_only) or ( self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0 @@ -2303,7 +2332,7 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1): self.defrag_segment_chain(z) if coalescence: self.defrag_breakpoints() - return merged_head + return new_lineage def defrag_segment_chain(self, z): y = z @@ -2363,11 +2392,11 @@ def common_ancestor_event(self, population_index, label): hull_i_ptr, hull_j_ptr = random_pair hull_i = self.hulls[hull_i_ptr] hull_j = self.hulls[hull_j_ptr] - x = hull_i.lineage_head - y = hull_j.lineage_head - pop.remove_individual(x, label) + x_lin = hull_i.lineage + y_lin = hull_j.lineage + pop.remove_individual(x_lin, label) pop.remove_hull(label, hull_i) - pop.remove_individual(y, label) + pop.remove_individual(y_lin, label) pop.remove_hull(label, hull_j) self.free_hull(hull_i) self.free_hull(hull_j) @@ -2375,17 +2404,18 @@ def common_ancestor_event(self, population_index, label): else: # Choose two ancestors uniformly. j = random.randint(0, pop.get_num_ancestors(label) - 1) - x = pop.remove(j, label) + x_lin = pop.remove(j, label) j = random.randint(0, pop.get_num_ancestors(label) - 1) - y = pop.remove(j, label) - + y_lin = pop.remove(j, label) + x = x_lin.head + y = y_lin.head self.merge_two_ancestors(population_index, label, x, y) def merge_two_ancestors(self, population_index, label, x, y, u=-1): pop = self.P[population_index] self.num_ca_events += 1 z = None - merged_head = None + new_lineage = None coalescence = False defrag_required = False while x is not None or y is not None: @@ -2463,8 +2493,8 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1): # loop tail; update alpha and integrate it into the state. if alpha is not None: if z is None: - pop.add(alpha, label) - merged_head = alpha + new_lineage = self.alloc_lineage(alpha) + pop.add(new_lineage, label) else: if (coalescence and not self.coalescing_segments_only) or ( self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0 @@ -2491,9 +2521,10 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1): if coalescence: self.defrag_breakpoints() - if merged_head is not None and self.model == "smc_k": + if new_lineage is not None and self.model == "smc_k": + merged_head = new_lineage.head assert merged_head.prev is None - hull = self.alloc_hull(merged_head.left, merged_head.right, merged_head) + hull = self.alloc_hull(merged_head.left, merged_head.right, new_lineage) while merged_head is not None: right = merged_head.right merged_head = merged_head.next @@ -2505,15 +2536,19 @@ def print_state(self, verify=False): for label in range(self.num_labels): print( "Recomb mass = ", - 0 - if self.recomb_mass_index is None - else self.recomb_mass_index[label].get_total(), + ( + 0 + if self.recomb_mass_index is None + else self.recomb_mass_index[label].get_total() + ), ) print( "GC mass = ", - 0 - if self.gc_mass_index is None - else self.gc_mass_index[label].get_total(), + ( + 0 + if self.gc_mass_index is None + else self.gc_mass_index[label].get_total() + ), ) print("Modifier events = ") for t, f, args in self.modifier_events: @@ -2564,7 +2599,10 @@ def print_state(self, verify=False): def verify_segments(self): for pop in self.P: for label in range(self.num_labels): - for head in pop.iter_label(label): + for lineage in pop.iter_label(label): + assert isinstance(lineage, Lineage) + head = lineage.head + assert head.lineage is lineage assert head.prev is None prev = head u = head.next @@ -2581,7 +2619,8 @@ def verify_overlaps(self): overlap_counter = OverlapCounter(self.L) for pop in self.P: for label in range(self.num_labels): - for u in pop.iter_label(label): + for lineage in pop.iter_label(label): + u = lineage.head while u is not None: overlap_counter.increment_interval(u.left, u.right) u = u.next @@ -2597,7 +2636,8 @@ def verify_overlaps(self): A[self.L] = -1 for pop in self.P: for label in range(self.num_labels): - for u in pop.iter_label(label): + for lineage in pop.iter_label(label): + u = lineage.head while u is not None: if u.left not in A: k = A.floor_key(u.left) @@ -2626,7 +2666,8 @@ def verify_mass_index(self, label, mass_index, rate_map, compute_left_bound): total_mass = 0 alt_total_mass = 0 for pop_index, pop in enumerate(self.P): - for u in pop.iter_label(label): + for lineage in pop.iter_label(label): + u = lineage.head assert u.prev is None left = compute_left_bound(u) while u is not None: @@ -2717,8 +2758,9 @@ def verify(self): self.verify_hulls() -def make_hull(a, L, offset=0): +def make_hull(lineage, L, offset=0): hull = Hull(-1) + a = lineage.head assert a.prev is None b = a tracked_hull = a.get_hull() @@ -2729,7 +2771,7 @@ def make_hull(a, L, offset=0): hull.right = min(right + offset, L) assert tracked_hull.left == hull.left assert tracked_hull.right == hull.right - assert tracked_hull.lineage_head == a + assert tracked_hull.lineage.head == a return hull diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 0a95abd78..0abd6e579 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -432,33 +432,19 @@ def test_one_gen_pedigree(self, num_founders): ts = self.run_script(f"0 --from-ts {ts_path} -r 1 --model=fixed_pedigree") assert len(ts.dump_tables().edges) == 0 - def test_smck(self): - ts = self.run_script("10 -L 1000 -d -r 0.01 --model smc_k") - assert ts.num_trees > 1 - for tree in ts.trees(): - assert tree.num_roots == 1 - - ts = self.run_script("10 -L 1000 -r 0.01 --model smc_k") - assert ts.num_trees > 1 - for tree in ts.trees(): - assert tree.num_roots == 1 - - ts = self.run_script("10 -L 1000 -r 0.01 --model smc_k --offset 0.50") - assert ts.num_trees > 1 - for tree in ts.trees(): - assert tree.num_roots == 1 - - ts = self.run_script("10 -L 1000 -d -r 0.01 --model smc_k -p 2 -g 0.1") - assert ts.num_trees > 1 - for tree in ts.trees(): - assert tree.num_roots == 1 - - ts = self.run_script("10 -L 1000 -d -c 0.04 2 --model smc_k") - assert ts.num_trees > 1 - for tree in ts.trees(): - assert tree.num_roots == 1 - - ts = self.run_script("10 -L 1000 -c 0.04 2 --model smc_k --offset 0.75") + @pytest.mark.parametrize( + "cmd", + [ + "10 -L 1000 -d -r 0.01 --model smc_k", + "10 -L 1000 -r 0.01 --model smc_k", + "10 -L 1000 -r 0.01 --model smc_k --offset 0.50", + "10 -L 1000 -d -r 0.01 --model smc_k -p 2 -g 0.1", + "10 -L 1000 -d -c 0.04 2 --model smc_k", + "10 -L 1000 -c 0.04 2 --model smc_k --offset 0.75", + ], + ) + def test_smck(self, cmd): + ts = self.run_script(cmd) assert ts.num_trees > 1 for tree in ts.trees(): assert tree.num_roots == 1 From 217799c9ead94b8d2ebcb8b417918cb193311687 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 24 Jul 2024 17:26:26 +0100 Subject: [PATCH 2/3] Refactor C code to use basic lineage_t struct Make some progress. Disable DTWF tests for now Minimal changes to get smck working Fix DTWF simulations Fixup incorrectly failing test Make results reproducible by seed Revert back to old seed for failing test Skip segfaulting pedigree and sweep tests Fix sweeps Fix pedigree sims Keep track of lineages --- lib/msprime.c | 387 ++++++++++++++++++++++++++------------ lib/msprime.h | 8 +- lib/tests/test_ancestry.c | 5 +- lib/tests/test_sweeps.c | 3 +- 4 files changed, 276 insertions(+), 127 deletions(-) diff --git a/lib/msprime.c b/lib/msprime.c index aeceb7248..a60b0dd46 100644 --- a/lib/msprime.c +++ b/lib/msprime.c @@ -1,5 +1,5 @@ /* -** Copyright (C) 2015-2021 University of Oxford +** Copyright (C) 2015-2024 University of Oxford ** ** This file is part of msprime. ** @@ -102,9 +102,11 @@ get_population_size(population_t *pop, double t) static int cmp_individual(const void *a, const void *b) { - const segment_t *ia = (const segment_t *) a; - const segment_t *ib = (const segment_t *) b; - return (ia->id > ib->id) - (ia->id < ib->id); + const lineage_t *ia = (const lineage_t *) a; + const lineage_t *ib = (const lineage_t *) b; + /* Compare by ID of the head segment to ensure reproducibility of + * results when we use the same seed */ + return (ia->head->id > ib->head->id) - (ia->head->id < ib->head->id); } /* For the segment priority queue we want to sort on the left @@ -197,8 +199,9 @@ segment_get_hull(segment_t *seg) while (seg->prev != NULL) { seg = seg->prev; } + tsk_bug_assert(seg->lineage != NULL); hull = seg->hull; - tsk_bug_assert(hull->lineage == seg); + tsk_bug_assert(hull->lineage == seg->lineage); return hull; } @@ -466,6 +469,7 @@ msp_reindex_segments(msp_t *self) avl_node_t *node; avl_tree_t *population_ancestors; segment_t *seg; + lineage_t *lin; size_t j; label_id_t label; @@ -473,7 +477,8 @@ msp_reindex_segments(msp_t *self) for (label = 0; label < (label_id_t) self->num_labels; label++) { population_ancestors = &self->populations[j].ancestors[label]; for (node = population_ancestors->head; node != NULL; node = node->next) { - for (seg = (segment_t *) node->item; seg != NULL; seg = seg->next) { + lin = (lineage_t *) node->item; + for (seg = lin->head; seg != NULL; seg = seg->next) { msp_set_segment_mass(self, seg); } } @@ -878,6 +883,26 @@ msp_alloc_segment(msp_t *self, double left, double right, tsk_id_t value, return seg; } +static lineage_t *MSP_WARN_UNUSED +msp_alloc_lineage(msp_t *self, segment_t *head) +{ + lineage_t *lin = NULL; + + if (object_heap_empty(&self->lineage_heap)) { + if (object_heap_expand(&self->lineage_heap) != 0) { + goto out; + } + } + lin = (lineage_t *) object_heap_alloc_object(&self->lineage_heap); + if (lin == NULL) { + goto out; + } + lin->head = head; + head->lineage = lin; +out: + return lin; +} + static segment_t *MSP_WARN_UNUSED msp_copy_segment(msp_t *self, const segment_t *seg) { @@ -886,13 +911,14 @@ msp_copy_segment(msp_t *self, const segment_t *seg) } static hull_t *MSP_WARN_UNUSED -msp_alloc_hull(msp_t *self, double left, double right, segment_t *lineage) +msp_alloc_hull(msp_t *self, double left, double right, lineage_t *lineage) { hull_t *hull = NULL; label_id_t label; uint32_t j; - label = lineage->label; + tsk_bug_assert(lineage != NULL); + label = lineage->head->label; if (object_heap_empty(&self->hull_heap[label])) { if (object_heap_expand(&self->hull_heap[label]) != 0) { @@ -923,9 +949,8 @@ msp_alloc_hull(msp_t *self, double left, double right, segment_t *lineage) hull->lineage = lineage; hull->count = 0; hull->insertion_order = UINT64_MAX; - tsk_bug_assert(lineage->prev == NULL); - lineage->hull = hull; - + tsk_bug_assert(lineage->head->prev == NULL); + lineage->head->hull = hull; out: return hull; } @@ -1041,6 +1066,11 @@ msp_alloc_memory_blocks(msp_t *self) if (ret != 0) { goto out; } + ret = object_heap_init( + &self->lineage_heap, sizeof(lineage_t), self->node_mapping_block_size, NULL); + if (ret != 0) { + goto out; + } /* allocate the segments */ for (j = 0; j < self->num_labels; j++) { ret = object_heap_init(&self->segment_heap[j], sizeof(segment_t), @@ -1141,6 +1171,7 @@ msp_free(msp_t *self) /* free the object heaps */ object_heap_free(&self->avl_node_heap); object_heap_free(&self->node_mapping_heap); + object_heap_free(&self->lineage_heap); rate_map_free(&self->recomb_map); rate_map_free(&self->gc_map); if (self->model.free != NULL) { @@ -1219,6 +1250,13 @@ msp_free_hullend(msp_t *self, hullend_t *hullend, label_id_t label) object_heap_free_object(&self->hullend_heap[label], hullend); } +static void +msp_free_lineage(msp_t *self, lineage_t *lineage) +{ + object_heap_free_object(&self->lineage_heap, lineage); + lineage->head = NULL; +} + /* * Returns the segment with the specified id. */ @@ -1296,7 +1334,7 @@ msp_insert_hull(msp_t *self, hull_t *hull) /* setting hull->count requires two steps step 1: num_starting before hull->left */ tsk_bug_assert(hull != NULL); - u = hull->lineage; + u = hull->lineage->head; hulls_left = &self->populations[u->population].hulls_left[u->label]; coal_mass_index = &self->populations[u->population].coal_mass_index[u->label]; /* insert hull into state */ @@ -1371,7 +1409,7 @@ msp_remove_hull(msp_t *self, hull_t *hull) fenwick_t *coal_mass_index; segment_t *u; - u = hull->lineage; + u = hull->lineage->head; tsk_bug_assert(u != NULL); hulls_left = &self->populations[u->population].hulls_left[u->label]; coal_mass_index = &self->populations[u->population].coal_mass_index[u->label]; @@ -1420,35 +1458,37 @@ msp_remove_hull(msp_t *self, hull_t *hull) } static inline int MSP_WARN_UNUSED -msp_insert_individual(msp_t *self, segment_t *u) +msp_insert_individual(msp_t *self, lineage_t *lin) { int ret = 0; avl_node_t *node; - tsk_bug_assert(u != NULL); + tsk_bug_assert(lin != NULL); + tsk_bug_assert(lin->head != NULL); node = msp_alloc_avl_node(self); if (node == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; } - avl_init_node(node, u); - node = avl_insert_node(msp_get_segment_population(self, u), node); + avl_init_node(node, lin); + node = avl_insert_node(msp_get_segment_population(self, lin->head), node); tsk_bug_assert(node != NULL); out: return ret; } static inline void -msp_remove_individual(msp_t *self, segment_t *u) +msp_remove_individual(msp_t *self, lineage_t *lin) { avl_node_t *node; - avl_tree_t *pop = msp_get_segment_population(self, u); - - tsk_bug_assert(u != NULL); - node = avl_search(pop, u); + avl_tree_t *pop; + tsk_bug_assert(lin != NULL); + pop = msp_get_segment_population(self, lin->head); + node = avl_search(pop, lin); tsk_bug_assert(node != NULL); avl_unlink_node(pop, node); msp_free_avl_node(self, node); + msp_free_lineage(self, lin); } static void @@ -1456,7 +1496,7 @@ msp_remove_individuals_from_population(msp_t *self, avl_tree_t *Q) { avl_node_t *node; for (node = Q->head; node != NULL; node = node->next) { - msp_remove_individual(self, (segment_t *) node->item); + msp_remove_individual(self, ((segment_t *) node->item)->lineage); } } @@ -1497,8 +1537,11 @@ static void msp_print_segment_chain(msp_t *MSP_UNUSED(self), segment_t *head, FILE *out) { segment_t *s = head; + lineage_t *lin = head->lineage; + + tsk_bug_assert(lin != NULL); - fprintf(out, "[pop=%d,label=%d]", s->population, s->label); + fprintf(out, "[%p,pop=%d,label=%d]", (void *) lin, s->population, s->label); while (s != NULL) { fprintf(out, "[(%.14g,%.14g) %d] ", s->left, s->right, (int) s->value); s = s->next; @@ -1517,6 +1560,7 @@ msp_verify_segment_index( size_t j, k; const double epsilon = 1e-10; avl_node_t *node; + lineage_t *lin; segment_t *u; for (k = 0; k < self->num_labels; k++) { @@ -1525,7 +1569,8 @@ msp_verify_segment_index( for (j = 0; j < self->num_populations; j++) { node = (&self->populations[j].ancestors[k])->head; while (node != NULL) { - u = (segment_t *) node->item; + lin = (lineage_t *) node->item; + u = lin->head; left = u->left; while (u != NULL) { if (u->prev != NULL) { @@ -1574,6 +1619,7 @@ msp_verify_segments(msp_t *self, bool verify_breakpoints) avl_node_t *node; segment_t *u; individual_t *ind; + lineage_t *lin; for (j = 0; j < self->input_position.nodes; j++) { for (u = self->root_segments[j]; u != NULL; u = u->next) { @@ -1589,7 +1635,9 @@ msp_verify_segments(msp_t *self, bool verify_breakpoints) for (j = 0; j < self->num_populations; j++) { node = (&self->populations[j].ancestors[k])->head; while (node != NULL) { - u = (segment_t *) node->item; + lin = (lineage_t *) node->item; + u = lin->head; + tsk_bug_assert(u->lineage == lin); tsk_bug_assert(u->prev == NULL); while (u != NULL) { label_segments++; @@ -1617,6 +1665,8 @@ msp_verify_segments(msp_t *self, bool verify_breakpoints) total_avl_nodes = msp_get_num_ancestors(self) + avl_count(&self->breakpoints) + avl_count(&self->overlap_counts) + avl_count(&self->non_empty_populations); + tsk_bug_assert(msp_get_num_ancestors(self) + == object_heap_get_num_allocated(&self->lineage_heap)); if (self->model.type == MSP_MODEL_SMC_K) { for (j = 0; j < self->num_populations; j++) { for (k = 0; k < self->num_labels; k++) { @@ -1778,6 +1828,7 @@ msp_verify_overlaps(msp_t *self) avl_node_t *node; node_mapping_t *nm; sampling_event_t se; + lineage_t *lin; segment_t *u; size_t j; uint32_t label, count; @@ -1798,7 +1849,8 @@ msp_verify_overlaps(msp_t *self) for (j = 0; j < self->num_populations; j++) { for (node = (&self->populations[j].ancestors[label])->head; node != NULL; node = node->next) { - for (u = (segment_t *) node->item; u != NULL; u = u->next) { + lin = (lineage_t *) node->item; + for (u = lin->head; u != NULL; u = u->next) { overlap_counter_increment_interval(&counter, u->left, u->right); } } @@ -1875,6 +1927,7 @@ msp_verify_hulls(msp_t *self) int count, num_coalescing_pairs; avl_tree_t *avl; avl_node_t *a, *b; + lineage_t *lin; segment_t *x, *y; hull_t *hull, hull_a, hull_b; hullend_t *hullend; @@ -1894,7 +1947,8 @@ msp_verify_hulls(msp_t *self) continue; } for (a = avl->head; a->next != NULL; a = a->next) { - x = (segment_t *) a->item; + lin = (lineage_t *) a->item; + x = lin->head; hull_right = x->hull->right; hull_a.left = x->left; while (x->next != NULL) { @@ -1905,7 +1959,8 @@ msp_verify_hulls(msp_t *self) self->sequence_length); tsk_bug_assert(hull_a.right == hull_right); for (b = a->next; b != NULL; b = b->next) { - y = (segment_t *) b->item; + lin = (lineage_t *) b->item; + y = lin->head; hull_b.left = y->left; while (y->next != NULL) { y = y->next; @@ -2289,6 +2344,8 @@ msp_print_state(msp_t *self, FILE *out) object_heap_print_state(&self->avl_node_heap, out); fprintf(out, "node_mapping_heap:"); object_heap_print_state(&self->node_mapping_heap, out); + fprintf(out, "lineage_heap:"); + object_heap_print_state(&self->lineage_heap, out); fflush(out); msp_verify(self, 0); out: @@ -2511,7 +2568,8 @@ msp_move_individual(msp_t *self, avl_node_t *node, avl_tree_t *source, population_id_t dest_pop, label_id_t dest_label) { int ret = 0; - segment_t *ind, *x, *y, *new_ind; + lineage_t *ind; + segment_t *x, *y; double recomb_mass, gc_mass; hull_t *hull, *new_hull, *h; @@ -2520,12 +2578,12 @@ msp_move_individual(msp_t *self, avl_node_t *node, avl_tree_t *source, goto out; } - ind = (segment_t *) node->item; + ind = (lineage_t *) node->item; avl_unlink_node(source, node); msp_free_avl_node(self, node); hull = NULL; if (self->model.type == MSP_MODEL_SMC_K) { - hull = segment_get_hull(ind); + hull = segment_get_hull(ind->head); tsk_bug_assert(hull != NULL); msp_remove_hull(self, hull); } @@ -2536,16 +2594,15 @@ msp_move_individual(msp_t *self, avl_node_t *node, avl_tree_t *source, if (ret < 0) { goto out; } - ret = msp_store_arg_edges(self, ind, TSK_NULL); + ret = msp_store_arg_edges(self, ind->head, TSK_NULL); if (ret != 0) { goto out; } } - if (ind->label == dest_label) { + if (ind->head->label == dest_label) { /* Need to set the population and label for each segment. */ - new_ind = ind; new_hull = hull; - for (x = ind; x != NULL; x = x->next) { + for (x = ind->head; x != NULL; x = x->next) { if (self->store_migrations) { ret = msp_record_migration( self, x->left, x->right, x->value, x->population, dest_pop); @@ -2558,7 +2615,6 @@ msp_move_individual(msp_t *self, avl_node_t *node, avl_tree_t *source, } else { /* Because we are changing to a different Fenwick tree we must allocate * new segments each time. */ - new_ind = NULL; y = NULL; new_hull = NULL; tsk_bug_assert(hull == NULL); @@ -2568,11 +2624,12 @@ msp_move_individual(msp_t *self, avl_node_t *node, avl_tree_t *source, // msp_free_hull(self, hull, ind->population, ind->label); //} h = new_hull; - for (x = ind; x != NULL; x = x->next) { + for (x = ind->head; x != NULL; x = x->next) { y = msp_alloc_segment(self, x->left, x->right, x->value, x->population, dest_label, y, NULL, h); - if (new_ind == NULL) { - new_ind = y; + if (x->prev == NULL) { + ind->head = y; + y->lineage = ind; } else { y->prev->next = y; } @@ -2591,13 +2648,13 @@ msp_move_individual(msp_t *self, avl_node_t *node, avl_tree_t *source, } } if (new_hull != NULL) { - new_hull->lineage = new_ind; + new_hull->lineage = ind; ret = msp_insert_hull(self, new_hull); if (ret != 0) { goto out; } } - ret = msp_insert_individual(self, new_ind); + ret = msp_insert_individual(self, ind); out: return ret; } @@ -2870,7 +2927,7 @@ msp_pedigree_initialise(msp_t *self) { int ret = 0; population_t *pop; - segment_t *segment; + lineage_t *lin; avl_node_t *a; label_id_t label = 0; tsk_size_t j; @@ -2898,8 +2955,8 @@ msp_pedigree_initialise(msp_t *self) for (j = 0; j < self->num_populations; j++) { pop = &self->populations[j]; for (a = pop->ancestors[label].head; a != NULL; a = a->next) { - segment = (segment_t *) a->item; - ret = msp_pedigree_add_sample_ancestry(self, segment); + lin = (lineage_t *) a->item; + ret = msp_pedigree_add_sample_ancestry(self, lin->head); if (ret != 0) { goto out; } @@ -2912,16 +2969,19 @@ msp_pedigree_initialise(msp_t *self) static int MSP_WARN_UNUSED msp_dtwf_recombine( - msp_t *self, segment_t *x, segment_t **u, segment_t **v, tsk_id_t *ind_nodes) + msp_t *self, segment_t *x_head, segment_t **u, segment_t **v, tsk_id_t *ind_nodes) { int ret = 0; int ix; + int j; double k; - segment_t *y, *z, *tail; + lineage_t *lin; + segment_t *x, *y, *z, *tail; segment_t s1, s2; segment_t *seg_tails[] = { &s1, &s2 }; segment_t **rec_heads[MSP_MAX_PED_PLOIDY] = { u, v }; + x = x_head; k = msp_dtwf_generate_breakpoint(self, x->left); s1.next = NULL; s2.next = NULL; @@ -2987,16 +3047,27 @@ msp_dtwf_recombine( x = y; } } - // Remove sentinal segments + // Remove sentinel segments *u = s1.next; *v = s2.next; + for (j = 0; j < MSP_MAX_PED_PLOIDY; j++) { + y = *rec_heads[j]; + if (y != x_head && y != NULL) { + lin = msp_alloc_lineage(self, y); + if (lin == NULL) { + ret = MSP_ERR_NO_MEMORY; + goto out; + } + } + } + if (*u != NULL && *v != NULL) { - for (int i = 0; i < MSP_MAX_PED_PLOIDY; i++) { - ret = msp_store_additional_nodes_edges(self, *rec_heads[i], ind_nodes[i], - MSP_NODE_IS_RE_EVENT, (*rec_heads[i])->population, TSK_NULL, - &ind_nodes[i]); + for (j = 0; j < MSP_MAX_PED_PLOIDY; j++) { + ret = msp_store_additional_nodes_edges(self, *rec_heads[j], ind_nodes[j], + MSP_NODE_IS_RE_EVENT, (*rec_heads[j])->population, TSK_NULL, + &ind_nodes[j]); if (ret < 0) { goto out; } @@ -3195,6 +3266,7 @@ msp_recombination_event(msp_t *self, label_id_t label, segment_t **lhs, segment_ { int ret = 0; double breakpoint; + lineage_t *right_lineage; segment_t *x, *y, *alpha, *lhs_tail; hull_t *lhs_hull, *rhs_hull; double lhs_right, rhs_right; @@ -3243,7 +3315,12 @@ msp_recombination_event(msp_t *self, label_id_t label, segment_t **lhs, segment_ } tsk_bug_assert(alpha->left < alpha->right); msp_set_segment_mass(self, alpha); - ret = msp_insert_individual(self, alpha); + right_lineage = msp_alloc_lineage(self, alpha); + if (right_lineage == NULL) { + ret = MSP_ERR_NO_MEMORY; + goto out; + } + ret = msp_insert_individual(self, right_lineage); if (ret != 0) { goto out; } @@ -3258,7 +3335,7 @@ msp_recombination_event(msp_t *self, label_id_t label, segment_t **lhs, segment_ self, lhs_hull, rhs_right, lhs_right, lhs_tail->population, label); /* create new hull for alpha */ - rhs_hull = msp_alloc_hull(self, alpha->left, rhs_right, alpha); + rhs_hull = msp_alloc_hull(self, alpha->left, rhs_right, alpha->lineage); if (rhs_hull == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3315,6 +3392,7 @@ msp_gene_conversion_event(msp_t *self, label_id_t label) { int ret = 0; segment_t *x, *y, *alpha, *head, *tail, *z, *new_individual_head; + lineage_t *new_lineage; double left_breakpoint, right_breakpoint, tl; bool insert_alpha; hull_t *hull = NULL; @@ -3499,13 +3577,18 @@ msp_gene_conversion_event(msp_t *self, label_id_t label) new_individual_head = head; } if (new_individual_head != NULL) { + new_lineage = msp_alloc_lineage(self, new_individual_head); + if (new_lineage == NULL) { + ret = MSP_ERR_NO_MEMORY; + goto out; + } + ret = msp_insert_individual(self, new_lineage); if (self->model.type == MSP_MODEL_SMC_K) { tsk_bug_assert(tract_hull_left < tract_hull_right); tract_hull_right = GSL_MIN( tract_hull_right + self->model.params.smc_k_coalescent.hull_offset, self->sequence_length); - hull = msp_alloc_hull( - self, tract_hull_left, tract_hull_right, new_individual_head); + hull = msp_alloc_hull(self, tract_hull_left, tract_hull_right, new_lineage); if (hull == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3515,7 +3598,6 @@ msp_gene_conversion_event(msp_t *self, label_id_t label) goto out; } } - ret = msp_insert_individual(self, new_individual_head); } else { self->num_noneffective_gc_events++; } @@ -3576,6 +3658,7 @@ msp_merge_two_ancestors(msp_t *self, population_id_t population_id, label_id_t l double l, r, l_min, r_max; avl_node_t *node; node_mapping_t *nm, search; + lineage_t *new_lineage; segment_t *x, *y, *z, *alpha, *beta, *merged_head; hull_t *hull = NULL; @@ -3704,7 +3787,12 @@ msp_merge_two_ancestors(msp_t *self, population_id_t population_id, label_id_t l } if (alpha != NULL) { if (z == NULL) { - ret = msp_insert_individual(self, alpha); + new_lineage = msp_alloc_lineage(self, alpha); + if (new_lineage == NULL) { + ret = MSP_ERR_NO_MEMORY; + goto out; + } + ret = msp_insert_individual(self, new_lineage); if (ret != 0) { goto out; } @@ -3765,8 +3853,8 @@ msp_merge_two_ancestors(msp_t *self, population_id_t population_id, label_id_t l y = y->next; } r += self->model.params.smc_k_coalescent.hull_offset; - hull = msp_alloc_hull( - self, merged_head->left, GSL_MIN(r, self->sequence_length), merged_head); + hull = msp_alloc_hull(self, merged_head->left, + GSL_MIN(r, self->sequence_length), merged_head->lineage); if (hull == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -3831,6 +3919,7 @@ msp_merge_ancestors(msp_t *self, avl_tree_t *Q, population_id_t population_id, segment_t *x, *z, *alpha; segment_t **H = NULL; segment_t *merged_head = NULL; + lineage_t *new_lineage = NULL; tsk_id_t individual = TSK_NULL; H = malloc(avl_count(Q) * sizeof(segment_t *)); @@ -3965,7 +4054,12 @@ msp_merge_ancestors(msp_t *self, avl_tree_t *Q, population_id_t population_id, if (alpha != NULL) { if (z == NULL) { merged_head = alpha; - ret = msp_insert_individual(self, alpha); + new_lineage = msp_alloc_lineage(self, alpha); + if (new_lineage == NULL) { + ret = MSP_ERR_NO_MEMORY; + goto out; + } + ret = msp_insert_individual(self, new_lineage); if (ret != 0) { goto out; } @@ -4039,9 +4133,10 @@ msp_merge_n_ancestors(msp_t *self, avl_tree_t *Q, population_id_t population_id, /* Migrate any of the child segments to this population, if necessary */ for (a = Q->head; a != NULL; a = a->next) { u = (segment_t *) a->item; + tsk_bug_assert(u->lineage != NULL); if (u->population != population_id) { current_pop = &self->populations[u->population]; - avl_node = avl_search(¤t_pop->ancestors[label], u); + avl_node = avl_search(¤t_pop->ancestors[label], u->lineage); tsk_bug_assert(avl_node != NULL); ret = msp_move_individual( self, avl_node, ¤t_pop->ancestors[label], population_id, label); @@ -4107,6 +4202,7 @@ msp_reset_memory_state(msp_t *self) avl_node_t *node; node_mapping_t *nm; population_t *pop; + lineage_t *lin; segment_t *u, *v; hull_t *x; hullend_t *y; @@ -4117,7 +4213,8 @@ msp_reset_memory_state(msp_t *self) pop = &self->populations[j]; for (label = 0; label < (label_id_t) self->num_labels; label++) { for (node = pop->ancestors[label].head; node != NULL; node = node->next) { - u = (segment_t *) node->item; + lin = (lineage_t *) node->item; + u = lin->head; while (u != NULL) { v = u->next; msp_free_segment(self, u); @@ -4125,6 +4222,7 @@ msp_reset_memory_state(msp_t *self) } avl_unlink_node(&pop->ancestors[label], node); msp_free_avl_node(self, node); + msp_free_lineage(self, lin); } if (pop->hulls_left != NULL) { for (node = pop->hulls_left[label].head; node != NULL; @@ -4165,6 +4263,7 @@ static int msp_insert_root_segments(msp_t *self, const segment_t *head, segment_t **new_head) { int ret = 0; + lineage_t *lineage; segment_t *copy, *prev; const segment_t *seg; double breakpoints[2]; @@ -4196,14 +4295,19 @@ msp_insert_root_segments(msp_t *self, const segment_t *head, segment_t **new_hea } copy->prev = prev; if (prev == NULL) { - ret = msp_insert_individual(self, copy); + lineage = msp_alloc_lineage(self, copy); + if (lineage == NULL) { + ret = MSP_ERR_NO_MEMORY; + goto out; + } + ret = msp_insert_individual(self, lineage); if (ret != 0) { goto out; } if (self->model.type == MSP_MODEL_SMC_K) { if (self->state != MSP_STATE_NEW) { /* correct hull->right is set at the end */ - hull = msp_alloc_hull(self, head->left, copy->right, copy); + hull = msp_alloc_hull(self, head->left, copy->right, lineage); if (hull == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -4575,7 +4679,8 @@ msp_initialise_smc_k(msp_t *self) avl_node_t *h_node, *a_node; hull_t *hull; double left, right; - segment_t *seg, *head; + lineage_t *lin; + segment_t *seg; for (population_id = 0; population_id < (population_id_t) self->num_populations; population_id++) { @@ -4584,10 +4689,10 @@ msp_initialise_smc_k(msp_t *self) hulls_left = &self->populations[population_id].hulls_left[label_id]; for (a_node = population_ancestors->head; a_node != NULL; a_node = a_node->next) { - seg = (segment_t *) a_node->item; + lin = (lineage_t *) a_node->item; + seg = lin->head; tsk_bug_assert(seg->prev == NULL); left = seg->left; - head = seg; while (seg != NULL) { right = seg->right; seg = seg->next; @@ -4595,7 +4700,7 @@ msp_initialise_smc_k(msp_t *self) /* insert into hulls_left */ right += self->model.params.smc_k_coalescent.hull_offset; right = GSL_MIN(right, self->sequence_length); - hull = msp_alloc_hull(self, left, right, head); + hull = msp_alloc_hull(self, left, right, lin); if (hull == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -4962,13 +5067,13 @@ msp_get_total_gc_left(msp_t *self) return total; } -static segment_t * +static lineage_t * msp_find_gc_left_individual(msp_t *self, label_id_t label, double value) { size_t j, num_ancestors, individual_index; avl_tree_t *ancestors; avl_node_t *node; - segment_t *ind; + lineage_t *ind; double mean_gc_rate = rate_map_get_total_mass(&self->gc_map) / self->sequence_length; individual_index = (size_t) floor(value / (mean_gc_rate * self->gc_tract_length)); @@ -4979,7 +5084,7 @@ msp_find_gc_left_individual(msp_t *self, label_id_t label, double value) /* Choose the correct individual */ node = avl_at(ancestors, (unsigned int) individual_index); assert(node != NULL); - ind = (segment_t *) node->item; + ind = (lineage_t *) node->item; return ind; } else { individual_index -= num_ancestors; @@ -5023,12 +5128,15 @@ msp_gene_conversion_left_event(msp_t *self, label_id_t label) const double gc_left_total = msp_get_total_gc_left(self); double h = gsl_rng_uniform(self->rng) * gc_left_total; double tl, bp, lhs_old_right, lhs_new_right; + lineage_t *lineage; segment_t *y, *x, *alpha; hull_t *rhs_hull; hull_t *lhs_hull = NULL; lhs_hull = NULL; - y = msp_find_gc_left_individual(self, label, h); + lineage = msp_find_gc_left_individual(self, label, h); + assert(lineage != NULL); + y = lineage->head; assert(y != NULL); /* generate tract length */ @@ -5107,6 +5215,22 @@ msp_gene_conversion_left_event(msp_t *self, label_id_t label) } lhs_new_right = y->right; + lineage = msp_alloc_lineage(self, alpha); + if (lineage == NULL) { + ret = MSP_ERR_NO_MEMORY; + goto out; + } + msp_set_segment_mass(self, alpha); + tsk_bug_assert(alpha->prev == NULL); + + ret = msp_insert_individual(self, lineage); + if (self->additional_nodes & MSP_NODE_IS_GC_EVENT) { + ret = msp_store_arg_gene_conversion(self, NULL, y, alpha); + if (ret != 0) { + goto out; + } + } + if (self->model.type == MSP_MODEL_SMC_K) { // lhs logic is identical to the lhs recombination event lhs_old_right = lhs_hull->right; @@ -5118,7 +5242,7 @@ msp_gene_conversion_left_event(msp_t *self, label_id_t label) // rhs tsk_bug_assert(alpha->left < lhs_old_right); - rhs_hull = msp_alloc_hull(self, alpha->left, lhs_old_right, alpha); + rhs_hull = msp_alloc_hull(self, alpha->left, lhs_old_right, alpha->lineage); if (rhs_hull == NULL) { ret = MSP_ERR_NO_MEMORY; goto out; @@ -5129,15 +5253,6 @@ msp_gene_conversion_left_event(msp_t *self, label_id_t label) } } - msp_set_segment_mass(self, alpha); - tsk_bug_assert(alpha->prev == NULL); - ret = msp_insert_individual(self, alpha); - if (self->additional_nodes & MSP_NODE_IS_GC_EVENT) { - ret = msp_store_arg_gene_conversion(self, NULL, y, alpha); - if (ret != 0) { - goto out; - } - } out: return ret; } @@ -5336,6 +5451,7 @@ msp_pedigree_process_common_ancestors(msp_t *self, individual_t *ind, tsk_size_t goto out; } if (genome != NULL) { + tsk_bug_assert(genome->lineage != NULL); tsk_bug_assert(genome->prev == NULL); if (parent == TSK_NULL) { @@ -5376,6 +5492,8 @@ msp_pedigree_process_common_ancestors(msp_t *self, individual_t *ind, tsk_size_t for (j = 0; j < ploidy; j++) { seg = parent_ancestry[j]; if (seg != NULL) { + tsk_bug_assert(seg->lineage != NULL); + tsk_bug_assert(seg->lineage->head == seg); tsk_bug_assert(seg->prev == NULL); ret = msp_pedigree_add_individual_common_ancestor( self, parent, seg, j); @@ -5383,7 +5501,7 @@ msp_pedigree_process_common_ancestors(msp_t *self, individual_t *ind, tsk_size_t goto out; } if (seg != genome) { - ret = msp_insert_individual(self, seg); + ret = msp_insert_individual(self, seg->lineage); if (ret != 0) { goto out; } @@ -5486,6 +5604,7 @@ msp_dtwf_generation(msp_t *self) segment_list_t **parents = NULL; segment_list_t *segment_mem = NULL; segment_list_t *s; + lineage_t *lin; avl_node_t *a, *node; avl_tree_t Q[2]; /* Only support single structured coalescent label for now. */ @@ -5542,7 +5661,8 @@ msp_dtwf_generation(msp_t *self) } for (s = parents[k]; s != NULL; s = s->next) { node = s->node; - x = (segment_t *) node->item; + lin = (lineage_t *) node->item; + x = lin->head; // Recombine ancestor // TODO Should this be the recombination rate going foward from x.left? if (rate_map_get_total_mass(&self->recomb_map) > 0) { @@ -5552,7 +5672,7 @@ msp_dtwf_generation(msp_t *self) } for (i = 0; i < 2; i++) { if (u[i] != NULL && u[i] != x) { - ret = msp_insert_individual(self, u[i]); + ret = msp_insert_individual(self, u[i]->lineage); if (ret != 0) { goto out; } @@ -5889,7 +6009,7 @@ msp_change_label(msp_t *self, segment_t *ind, label_id_t label) avl_node_t *node; /* Find the this individual in the AVL tree. */ - node = avl_search(pop, ind); + node = avl_search(pop, ind->lineage); tsk_bug_assert(node != NULL); ret = msp_move_individual(self, node, pop, ind->population, label); return ret; @@ -5908,6 +6028,9 @@ msp_sweep_recombination_event( if (ret != 0) { goto out; } + tsk_bug_assert(lhs->lineage != NULL); + tsk_bug_assert(rhs->lineage != NULL); + /* NOTE: we can look at rhs->left when we compare to the sweep site. */ r = gsl_rng_uniform(self->rng); if (sweep_locus < rhs->left) { @@ -5987,6 +6110,7 @@ msp_run_sweep(msp_t *self) if (ret != 0) { goto out; } + msp_verify(self, 0); ret = msp_sweep_initialise(self, allele_frequency[0]); if (ret != 0) { goto out; @@ -5995,6 +6119,7 @@ msp_run_sweep(msp_t *self) curr_step = 1; while (msp_get_num_ancestors(self) > 0 && curr_step < num_steps) { events++; + msp_verify(self, 0); /* Set pop sizes & rec_rates */ for (j = 0; j < self->num_labels; j++) { label = (label_id_t) j; @@ -6040,6 +6165,7 @@ msp_run_sweep(msp_t *self) rec_rates[1], self->ploidy); printf("event_prob: %g rand: %g\n", event_prob, event_rand); */ + event_prob *= 1.0 - total_rate; curr_step++; @@ -6057,7 +6183,6 @@ msp_run_sweep(msp_t *self) t_unscaled = time[curr_step - 1] * self->ploidy * pop_size; tsk_bug_assert(t_unscaled > 0); self->time = t_start + t_unscaled; - /* printf("event time: %g\n", self->time); */ if (tmp_rand < e_sum / sweep_pop_tot_rate) { /* coalescent in b background */ ret = self->common_ancestor_event(self, 0, 0); @@ -6082,7 +6207,6 @@ msp_run_sweep(msp_t *self) if (ret != 0) { goto out; } - /* msp_print_state(self, stdout); */ } /* TODO we should probably support fixed events here using @@ -6177,6 +6301,7 @@ msp_insert_uncoalesced_edges(msp_t *self) label_id_t label; avl_node_t *a; segment_t *seg; + lineage_t *lin; tsk_id_t node; int64_t edge_start; tsk_node_table_t *nodes = &self->tables->nodes; @@ -6194,7 +6319,8 @@ msp_insert_uncoalesced_edges(msp_t *self) * could only have arisen as the result of a coalescence and so this * node really does represent the current ancestor */ node = TSK_NULL; - for (seg = (segment_t *) a->item; seg != NULL; seg = seg->next) { + lin = (lineage_t *) a->item; + for (seg = lin->head; seg != NULL; seg = seg->next) { if (nodes->time[seg->value] == current_time) { node = seg->value; break; @@ -6211,7 +6337,7 @@ msp_insert_uncoalesced_edges(msp_t *self) } /* For every segment add an edge pointing to this new node */ - for (seg = (segment_t *) a->item; seg != NULL; seg = seg->next) { + for (seg = lin->head; seg != NULL; seg = seg->next) { if (seg->value != node) { tsk_bug_assert(nodes->time[node] > nodes->time[seg->value]); ret = tsk_edge_table_add_row(&self->tables->edges, seg->left, @@ -6475,6 +6601,7 @@ msp_get_ancestors(msp_t *self, segment_t **ancestors) int ret = -1; avl_node_t *node; avl_tree_t *population_ancestors; + lineage_t *lineage; size_t j; label_id_t label; size_t k = 0; @@ -6483,7 +6610,8 @@ msp_get_ancestors(msp_t *self, segment_t **ancestors) for (label = 0; label < (label_id_t) self->num_labels; label++) { population_ancestors = &self->populations[j].ancestors[label]; for (node = population_ancestors->head; node != NULL; node = node->next) { - ancestors[k] = (segment_t *) node->item; + lineage = (lineage_t *) node->item; + ancestors[k] = lineage->head; k++; } } @@ -7200,6 +7328,7 @@ msp_simple_bottleneck(msp_t *self, demographic_event_t *event) population_id_t N = (population_id_t) self->num_populations; avl_node_t *node, *next, *q_node; avl_tree_t *pop, Q; + lineage_t *lin; segment_t *u; label_id_t label = 0; /* For now only support label 0 */ @@ -7222,9 +7351,11 @@ msp_simple_bottleneck(msp_t *self, demographic_event_t *event) while (node != NULL) { next = node->next; if (gsl_rng_uniform(self->rng) < p) { - u = (segment_t *) node->item; + lin = (lineage_t *) node->item; + u = lin->head; avl_unlink_node(pop, node); msp_free_avl_node(self, node); + msp_free_lineage(self, lin); q_node = msp_alloc_avl_node(self); if (q_node == NULL) { ret = MSP_ERR_NO_MEMORY; @@ -7299,7 +7430,7 @@ msp_instantaneous_bottleneck(msp_t *self, demographic_event_t *event) double rate, t; avl_tree_t *pop; avl_node_t *node, *set_node; - segment_t *individual; + lineage_t *lin; label_id_t label = 0; /* For now only support label 0 */ if (self->model.type == MSP_MODEL_DTWF) { @@ -7373,7 +7504,7 @@ msp_instantaneous_bottleneck(msp_t *self, demographic_event_t *event) if (u >= (tsk_id_t) n) { /* Remove this node from the population, and add it into the * set for the root at u */ - individual = (segment_t *) avl_nodes[j]->item; + lin = (lineage_t *) avl_nodes[j]->item; avl_unlink_node(pop, avl_nodes[j]); msp_free_avl_node(self, avl_nodes[j]); set_node = msp_alloc_avl_node(self); @@ -7381,7 +7512,8 @@ msp_instantaneous_bottleneck(msp_t *self, demographic_event_t *event) ret = MSP_ERR_NO_MEMORY; goto out; } - avl_init_node(set_node, individual); + avl_init_node(set_node, lin->head); + msp_free_lineage(self, lin); set_node = avl_insert_node(&sets[u], set_node); tsk_bug_assert(set_node != NULL); } @@ -7396,18 +7528,10 @@ msp_instantaneous_bottleneck(msp_t *self, demographic_event_t *event) } } out: - if (lineages != NULL) { - free(lineages); - } - if (pi != NULL) { - free(pi); - } - if (sets != NULL) { - free(sets); - } - if (avl_nodes != NULL) { - free(avl_nodes); - } + msp_safe_free(lineages); + msp_safe_free(pi); + msp_safe_free(sets); + msp_safe_free(avl_nodes); return ret; } @@ -7459,6 +7583,7 @@ msp_census_event(msp_t *self, demographic_event_t *event) avl_tree_t *ancestors; avl_node_t *node; segment_t *seg; + lineage_t *lin; tsk_id_t i, j; tsk_id_t u; @@ -7470,8 +7595,8 @@ msp_census_event(msp_t *self, demographic_event_t *event) node = ancestors->head; while (node != NULL) { - seg = (segment_t *) node->item; - + lin = (lineage_t *) node->item; + seg = lin->head; while (seg != NULL) { // Add an edge to the edge table. ret = tsk_node_table_add_row(&self->tables->nodes, @@ -7635,6 +7760,7 @@ msp_std_common_ancestor_event( uint32_t j, n; avl_tree_t *ancestors; avl_node_t *x_node, *y_node, *node; + lineage_t *x_lin, *y_lin; segment_t *x, *y; ancestors = &self->populations[population_id].ancestors[label]; @@ -7643,12 +7769,14 @@ msp_std_common_ancestor_event( j = (uint32_t) gsl_rng_uniform_int(self->rng, n); x_node = avl_at(ancestors, j); tsk_bug_assert(x_node != NULL); - x = (segment_t *) x_node->item; + x_lin = (lineage_t *) x_node->item; + x = x_lin->head; avl_unlink_node(ancestors, x_node); j = (uint32_t) gsl_rng_uniform_int(self->rng, n - 1); y_node = avl_at(ancestors, j); tsk_bug_assert(y_node != NULL); - y = (segment_t *) y_node->item; + y_lin = (lineage_t *) y_node->item; + y = y_lin->head; avl_unlink_node(ancestors, y_node); /* For SMC and SMC' models we reject some events to get the required @@ -7656,16 +7784,18 @@ msp_std_common_ancestor_event( if (msp_reject_ca_event(self, x, y)) { self->num_rejected_ca_events++; /* insert x and y back into the population */ - tsk_bug_assert(x_node->item == x); + tsk_bug_assert(x_node->item == x_lin); node = avl_insert_node(ancestors, x_node); tsk_bug_assert(node != NULL); - tsk_bug_assert(y_node->item == y); + tsk_bug_assert(y_node->item == y_lin); node = avl_insert_node(ancestors, y_node); tsk_bug_assert(node != NULL); } else { self->num_ca_events++; msp_free_avl_node(self, x_node); + msp_free_lineage(self, x_lin); msp_free_avl_node(self, y_node); + msp_free_lineage(self, y_lin); ret = msp_merge_two_ancestors(self, population_id, label, x, y, TSK_NULL, NULL); } return ret; @@ -7711,6 +7841,7 @@ msp_smc_k_common_ancestor_event( avl_tree_t *avl; avl_node_t *x_node, *y_node, *search; hull_t *x_hull, *y_hull = NULL; + lineage_t *x_lin, *y_lin; segment_t *x, *y; /* find first hull */ @@ -7741,14 +7872,15 @@ msp_smc_k_common_ancestor_event( /* retrieve ancestors linked to both hulls */ avl = &self->populations[population_id].ancestors[label]; - x = (segment_t *) x_hull->lineage; - x_node = avl_search(avl, x); + x_lin = x_hull->lineage; + x = x_lin->head; + x_node = avl_search(avl, x_lin); tsk_bug_assert(x_node != NULL); avl_unlink_node(avl, x_node); - y = (segment_t *) y_hull->lineage; - y_node = avl_search(avl, y); + y_lin = y_hull->lineage; + y = y_lin->head; + y_node = avl_search(avl, y_lin); tsk_bug_assert(y_node != NULL); - y = (segment_t *) y_node->item; avl_unlink_node(avl, y_node); self->num_ca_events++; @@ -7756,8 +7888,9 @@ msp_smc_k_common_ancestor_event( msp_free_hull(self, y_hull, population_id, label); msp_free_avl_node(self, x_node); msp_free_avl_node(self, y_node); + msp_free_lineage(self, x_lin); + msp_free_lineage(self, y_lin); ret = msp_merge_two_ancestors(self, population_id, label, x, y, TSK_NULL, NULL); - return ret; } @@ -7836,6 +7969,7 @@ msp_dirac_common_ancestor_event(msp_t *self, population_id_t pop_id, label_id_t avl_tree_t *ancestors, Q[4]; /* MSVC won't let us use num_pots here */ avl_node_t *x_node, *y_node; segment_t *x, *y; + lineage_t *x_lin, *y_lin; double nC2, p; double psi = self->model.params.dirac_coalescent.psi; @@ -7865,16 +7999,20 @@ msp_dirac_common_ancestor_event(msp_t *self, population_id_t pop_id, label_id_t j = (uint32_t) gsl_rng_uniform_int(self->rng, n); x_node = avl_at(ancestors, j); tsk_bug_assert(x_node != NULL); - x = (segment_t *) x_node->item; + x_lin = (lineage_t *) x_node->item; + x = x_lin->head; avl_unlink_node(ancestors, x_node); j = (uint32_t) gsl_rng_uniform_int(self->rng, n - 1); y_node = avl_at(ancestors, j); tsk_bug_assert(y_node != NULL); - y = (segment_t *) y_node->item; + y_lin = (lineage_t *) y_node->item; + y = y_lin->head; avl_unlink_node(ancestors, y_node); self->num_ca_events++; msp_free_avl_node(self, x_node); + msp_free_lineage(self, x_lin); msp_free_avl_node(self, y_node); + msp_free_lineage(self, y_lin); ret = msp_merge_two_ancestors(self, pop_id, label, x, y, TSK_NULL, NULL); } } else { @@ -8003,6 +8141,7 @@ msp_multi_merger_common_ancestor_event( uint32_t j, i, l; avl_node_t *node, *q_node; segment_t *u; + lineage_t *lin; uint32_t pot_size; uint32_t cumul_pot_size = 0; @@ -8019,9 +8158,11 @@ msp_multi_merger_common_ancestor_event( node = avl_at(ancestors, j); tsk_bug_assert(node != NULL); - u = (segment_t *) node->item; + lin = (lineage_t *) node->item; + u = lin->head; avl_unlink_node(ancestors, node); msp_free_avl_node(self, node); + msp_free_lineage(self, lin); q_node = msp_alloc_avl_node(self); if (q_node == NULL) { diff --git a/lib/msprime.h b/lib/msprime.h index dc42a40f4..f27a22c6f 100644 --- a/lib/msprime.h +++ b/lib/msprime.h @@ -85,8 +85,13 @@ typedef struct segment_t_t { struct segment_t_t *prev; struct segment_t_t *next; struct hull_t_t *hull; + struct lineage_t_t *lineage; } segment_t; +typedef struct lineage_t_t { + segment_t *head; +} lineage_t; + typedef struct { double position; uint32_t value; @@ -95,7 +100,7 @@ typedef struct { typedef struct hull_t_t { double left; double right; - segment_t *lineage; + lineage_t *lineage; size_t id; uint64_t count; uint64_t insertion_order; @@ -277,6 +282,7 @@ typedef struct _msp_t { /* memory management */ object_heap_t avl_node_heap; object_heap_t node_mapping_heap; + object_heap_t lineage_heap; /* We keep an independent segment heap for each label */ object_heap_t *segment_heap; /* We keep an independent hull heap for each label */ diff --git a/lib/tests/test_ancestry.c b/lib/tests/test_ancestry.c index 1c4380256..e165d0bc9 100644 --- a/lib/tests/test_ancestry.c +++ b/lib/tests/test_ancestry.c @@ -1,5 +1,5 @@ /* -** Copyright (C) 2016-2021 University of Oxford +** Copyright (C) 2016-2024 University of Oxford ** ** This file is part of msprime. ** @@ -1645,6 +1645,7 @@ test_multiple_mergers_unary_nodes(void) CU_ASSERT_EQUAL(ret, 0); msp_verify(&msp, 0); + /* msp_print_state(&msp, stdout); */ CU_ASSERT_TRUE(msp_get_num_breakpoints(&msp) > 0); // verify whether there is at least one unary node num_edges = tables.edges.num_rows; @@ -4307,7 +4308,7 @@ main(int argc, char **argv) { "test_multiple_mergers_growth_rate", test_multiple_mergers_growth_rate }, { "test_dirac_coalescent_bad_parameters", test_dirac_coalescent_bad_parameters }, { "test_beta_coalescent_bad_parameters", test_beta_coalescent_bad_parameters }, - { "test_multipe_mergers_unary_nodes", test_multiple_mergers_unary_nodes }, + { "test_multiple_mergers_unary_nodes", test_multiple_mergers_unary_nodes }, { "test_simulator_getters_setters", test_simulator_getters_setters }, { "test_demographic_events", test_demographic_events }, diff --git a/lib/tests/test_sweeps.c b/lib/tests/test_sweeps.c index aa3f149d2..93f107750 100644 --- a/lib/tests/test_sweeps.c +++ b/lib/tests/test_sweeps.c @@ -374,8 +374,9 @@ static void test_sweep_genic_selection_mimic_msms(void) { /* To mimic the nrepeats = 300 parameter in msms cmdline arguments*/ - for (int i = 0; i < 300; i++) + for (int i = 0; i < 300; i++) { sweep_genic_selection_mimic_msms_single_run(i + 1); + } } int From 8ee8cb11ff0b69d60230117563dc945aecb0e786 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 25 Jul 2024 12:01:55 +0100 Subject: [PATCH 3/3] Fix warning suppressions on provenance tests Newer versions of jsonschema emit a DeprecationWarning --- tests/test_provenance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_provenance.py b/tests/test_provenance.py index 805235155..5eba25b85 100644 --- a/tests/test_provenance.py +++ b/tests/test_provenance.py @@ -80,7 +80,7 @@ class TestBuildObjects: def decode(self, prov): # Supress warnings about schemas here - it's no big deal and # not easy to fix - with pytest.warns(UserWarning): + with pytest.warns((UserWarning, DeprecationWarning)): builder = pjs.ObjectBuilder(tskit.provenance.get_schema()) ns = builder.build_classes() return ns.TskitProvenance.from_json(prov)