From 97f7f914d874303360612f2a9976870956ee334d Mon Sep 17 00:00:00 2001 From: Hannes Becher Date: Tue, 8 Oct 2024 14:50:10 +0100 Subject: [PATCH] N function --- msprime/pedigrees.py | 55 +++++++++++++++++++++++++++++++++++------- tests/test_pedigree.py | 36 +++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 9 deletions(-) diff --git a/msprime/pedigrees.py b/msprime/pedigrees.py index 09c5399ab..23f5a1934 100644 --- a/msprime/pedigrees.py +++ b/msprime/pedigrees.py @@ -402,7 +402,7 @@ def sim_pedigree_backward( builder, rng, *, - population_size, + N_function, num_samples, end_time, ): @@ -412,7 +412,9 @@ def sim_pedigree_backward( # Because we don't have overlapping generations we can add the ancestors # to the pedigree once the previous generation has been produced for time in range(0, end_time): - population = np.arange(population_size, dtype=np.int32) + # get current population size from function, add 1 to time to match behaviour of + # sim_pedigree_forward + population = np.arange(N_function(time + 1), dtype=np.int32) parents = rng.choice(population, (len(ancestors), 2)) unique_parents = np.unique(parents) parent_ids = np.searchsorted(unique_parents, parents).astype(np.int32) @@ -432,7 +434,7 @@ def sim_pedigree_forward( builder, rng, *, - population_size, + N_function, end_time, ): population = np.array([tskit.NULL], dtype=np.int32) @@ -440,7 +442,7 @@ def sim_pedigree_forward( # To make the semantics compatible with dtwf, the end_time means the # *end* of generation end_time for time in reversed(range(end_time + 1)): - N = population_size # This could be derived from the Demography + N = N_function(time) # This could be derived from the Demography # NB this is *with* replacement, so 1 / N chance of selfing parents = rng.choice(population, (N, 2)) population = builder.add_individuals(parents=parents, time=time) @@ -456,29 +458,64 @@ def sim_pedigree( end_time=None, direction="forward", ): + """ + Simulate a pedigree with the specified parameters. + + :param int|function population_size: The population size at each + generation. This can be a single value, or a function that takes + the generation number as input and returns the population size. + If a function is provided, the function is called with the + generation number 0 to determine the number of samples. + :param int|None num_samples: The number of samples (at generation 0). + If None, the number of samples is equal to the population size at + generation 0. + :param float|None sequence_length: The sequence length of the resulting + TableCollection. + :param int|None random_seed: The random seed to use. If None, the random + seed is not set. + :param int end_time: The end time of the pedigree (in generations back + in time). + :param str direction: The direction of time in the pedigree. Either + 'forward' or 'backward'. + + :return: The TableCollection containing the pedigree data. + :rtype: tskit.TableCollection + """ + # Internal utility for generating pedigree data. This function is not # part of the public API and subject to arbitrary changes/removal # in the future. - num_samples = population_size if num_samples is None else num_samples + + # allow for population_size to be a single value or a function + if not callable(population_size): + + def N_function(_): + return population_size + + else: + N_function = population_size + + num_samples = N_function(0) if num_samples is None else num_samples builder = PedigreeBuilder() rng = np.random.RandomState(random_seed) if direction == "forward": - if num_samples != population_size: + if num_samples != N_function(0): raise ValueError( - "num_samples must be equal to population_size for forward simulation" + "if at all specified, num_samples must be equal to population_size " + "at generation 0 for forward simulation" ) tables = sim_pedigree_forward( builder, rng, - population_size=population_size, + N_function=N_function, end_time=end_time, ) elif direction == "backward": tables = sim_pedigree_backward( builder, rng, - population_size=population_size, + N_function=N_function, num_samples=num_samples, end_time=end_time, ) diff --git a/tests/test_pedigree.py b/tests/test_pedigree.py index 765ec41b1..a9c4efa25 100644 --- a/tests/test_pedigree.py +++ b/tests/test_pedigree.py @@ -373,6 +373,42 @@ def test_valid_pedigree(self): tc.individuals.append(row) tc.tree_sequence() # creating tree sequence should succeed + @pytest.mark.parametrize("direction", ["backward", "forward"]) + def test_equal_N_Nfunc(self, direction): + seed = np.random.randint(1e6) + p1 = pedigrees.sim_pedigree( + num_samples=10, + population_size=10, + end_time=20, + random_seed=seed, + direction=direction, + ) + p2 = pedigrees.sim_pedigree( + num_samples=10, + population_size=lambda _: 10, + end_time=20, + random_seed=seed, + direction=direction, + ) + assert p1 == p2 + + @pytest.mark.parametrize("direction", ["backward", "forward"]) + def test_pop_collapse(self, direction): + seed = np.random.randint(1e6) + collapse_time = 10 + ped = pedigrees.sim_pedigree( + num_samples=100, + population_size=lambda t: 100 if t < collapse_time else 1, + end_time=20, + random_seed=seed, + direction=direction, + ) + times_old_nodes = ped.nodes.time[ped.nodes.time > collapse_time - 1] + _, node_counts = np.unique(times_old_nodes, return_counts=True) + # if we have a collapse down to one (diploid) individual, there should be only + # two nodes at each time point after the collapse + assert all(node_counts == 2) + def join_pedigrees(tables_list): """