diff --git a/tests/test_functions.py b/tests/test_functions.py index 349cef70..993038fc 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -33,7 +33,7 @@ import tsinfer import tsdate -from tsdate.date import (SpansBySamples, PriorParams, +from tsdate.date import (SpansBySamples, PriorParams, LIN, LOG, ConditionalCoalescentTimes, fill_prior, Likelihoods, LogLikelihoods, LogLikelihoodsStreaming, InOutAlgorithms, NodeGridValues, gamma_approx, constrain_ages_topo) # NOQA @@ -189,12 +189,19 @@ def test_larger_find_node_tip_weights(self): self.verify_weights(ts) def test_dangling_nodes_warn(self): - ts = utility_functions.single_tree_ts_n3_dangling() + ts = utility_functions.single_tree_ts_n2_dangling() with self.assertLogs(level="WARNING") as log: self.verify_weights(ts) self.assertGreater(len(log.output), 0) self.assertIn("dangling", log.output[0]) + def test_simple_non_contemporaneous(self): + ts = utility_functions.two_tree_ts_n3_non_contemporaneous() + n = len([s for s in ts.samples() if ts.node(s).time == 0]) + span_data = self.verify_weights(ts) + self.assertEqual(span_data.lookup_weight(4, n, 2), 0.2) # 2 contemporanous tips + self.assertEqual(span_data.lookup_weight(4, n, 1), 0.8) # only 1 contemporanous + @unittest.skip("YAN to fix") def test_truncated_nodes(self): Ne = 1e2 @@ -337,9 +344,10 @@ class TestMixturePrior(unittest.TestCase): def get_mixture_prior_params(self, ts, prior_distr): span_data = SpansBySamples(ts) priors = ConditionalCoalescentTimes(None, prior_distr=prior_distr) - priors.add(ts.num_samples, approximate=False) + for total_fixed in span_data.total_fixed_at_0_counts: + priors.add(total_fixed, approximate=False) mixture_prior = priors.get_mixture_prior_params(span_data) - return(mixture_prior) + return mixture_prior def test_one_tree_n2(self): ts = utility_functions.single_tree_ts_n2() @@ -420,12 +428,30 @@ def test_two_tree_mutation_ts(self): self.assertTrue( np.allclose(mixture_prior[5, self.alpha_beta], [1.6, 1.2])) + def test_simple_non_contemporaneous(self): + ts = utility_functions.two_tree_ts_n3_non_contemporaneous() + mixture_prior = self.get_mixture_prior_params(ts, 'gamma') + self.assertTrue( + np.allclose(mixture_prior[4, self.alpha_beta], [0.11111, 0.55555])) + + def test_simulated_non_contemporaneous(self): + samples = [ + msprime.Sample(population=0, time=0), + msprime.Sample(population=0, time=0), + msprime.Sample(population=0, time=0), + msprime.Sample(population=0, time=1.0) + ] + ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=123) + self.get_mixture_prior_params(ts, 'lognorm') + self.get_mixture_prior_params(ts, 'gamma') + class TestPriorVals(unittest.TestCase): def verify_prior_vals(self, ts, prior_distr): span_data = SpansBySamples(ts) priors = ConditionalCoalescentTimes(None, prior_distr=prior_distr) - priors.add(ts.num_samples, approximate=False) + for total_fixed in span_data.total_fixed_at_0_counts: + priors.add(total_fixed, approximate=False) grid = np.linspace(0, 3, 3) mixture_prior = priors.get_mixture_prior_params(span_data) prior_vals = fill_prior(mixture_prior, grid, ts, prior_distr=prior_distr) @@ -470,6 +496,23 @@ def test_tree_with_unary_nodes(self): self.assertTrue(np.allclose(prior_vals[4], [0, 1, 0.093389])) self.assertTrue(np.allclose(prior_vals[3], [0, 1, 0.011109])) + def test_simple_non_contemporaneous(self): + ts = utility_functions.two_tree_ts_n3_non_contemporaneous() + prior_vals = self.verify_prior_vals(ts, 'gamma') + self.assertEqual(prior_vals.fixed_time(2), ts.node(2).time) + + def test_simulated_non_contemporaneous(self): + samples = [ + msprime.Sample(population=0, time=0), + msprime.Sample(population=0, time=0), + msprime.Sample(population=0, time=0), + msprime.Sample(population=0, time=1.0) + ] + ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=123) + prior_vals = self.verify_prior_vals(ts, 'gamma') + print(prior_vals.timepoints) + raise + class TestLikelihoodClass(unittest.TestCase): def poisson(self, l, x, normalize=True): @@ -671,102 +714,91 @@ def test_logsumexp_streaming(self): class TestNodeGridValuesClass(unittest.TestCase): - # TODO - needs a few more tests in here def test_init(self): - num_nodes = 5 - ids = np.array([3, 4]) + nonfixed_ids = np.array([3, 2]) timepoints = np.array(range(10)) - store = NodeGridValues(num_nodes, ids, timepoints, fill_value=6) - self.assertEquals(store.grid_data.shape, (len(ids), len(timepoints))) - self.assertEquals(len(store.fixed_data), (num_nodes-len(ids))) + store = NodeGridValues(timepoints, gridnodes=nonfixed_ids, fill_value=6) + self.assertEquals(store.grid_data.shape, (len(nonfixed_ids), len(timepoints))) self.assertTrue(np.all(store.grid_data == 6)) - self.assertTrue(np.all(store.fixed_data == 6)) + for i in range(np.max(nonfixed_ids)+1): + if i in nonfixed_ids: + self.assertTrue(np.all(store[i] == 6)) + else: + with self.assertRaises(IndexError): + _ = store[i] - ids = np.array([3, 4], dtype=np.int32) - store = NodeGridValues(num_nodes, ids, timepoints, fill_value=5) - self.assertEquals(store.grid_data.shape, (len(ids), len(timepoints))) - self.assertEquals(len(store.fixed_data), num_nodes-len(ids)) - self.assertTrue(np.all(store.fixed_data == 5)) + def test_probability_spaces(self): + nonfixed_ids = np.array([3, 4]) + timepoints = np.array(range(10)) + store = NodeGridValues(timepoints, gridnodes=nonfixed_ids, fill_value=0.5) + self.assertTrue(np.all(store.grid_data == 0.5)) + store.force_probability_space(LIN) + self.assertTrue(np.all(store.grid_data == 0.5)) + store.force_probability_space(LOG) + self.assertTrue(np.allclose(store.grid_data, np.log(0.5))) + store.force_probability_space(LOG) + self.assertTrue(np.allclose(store.grid_data, np.log(0.5))) + store.force_probability_space(LIN) + self.assertTrue(np.all(store.grid_data == 0.5)) + self.assertRaises(ValueError, store.force_probability_space, "foobar") def test_set_and_get(self): - num_nodes = 5 - grid_size = 2 + timepoints = [0, 1.1] fill = {} - for ids in ([3, 4], []): + for nonfixed_ids in ([3, 4], [0]): np.random.seed(1) - store = NodeGridValues( - num_nodes, np.array(ids, dtype=np.int32), np.array(range(grid_size))) - for i in range(num_nodes): - fill[i] = np.random.random(grid_size if i in ids else None) - store[i] = fill[i] - for i in range(num_nodes): + store = NodeGridValues(timepoints, gridnodes=nonfixed_ids) + for i in range(5): + fill[i] = np.random.random(len(store.timepoints)) + if i in nonfixed_ids: + store[i] = fill[i] + else: + with self.assertRaises(IndexError): + store[i] = fill[i] + for i in nonfixed_ids: self.assertTrue(np.all(fill[i] == store[i])) - self.assertRaises(IndexError, store.__getitem__, num_nodes) def test_bad_init(self): - ids = [3, 4] - self.assertRaises(ValueError, NodeGridValues, 3, np.array(ids), - np.array([0, 1.2, 2])) - self.assertRaises(AttributeError, NodeGridValues, 5, np.array(ids), -1) - self.assertRaises(ValueError, NodeGridValues, 5, np.array([-1]), - np.array([0, 1.2, 2])) + timepoints = [0, 1.2, 2] + nonfixed_ids = [4, 0] + NodeGridValues(timepoints, gridnodes=nonfixed_ids) + # duplicate ids + self.assertRaises(ValueError, NodeGridValues, timepoints, gridnodes=[4, 4, 0]) + # bad ids + self.assertRaises( + ValueError, NodeGridValues, timepoints, gridnodes=np.array([[1, 4], [2, 0]])) + self.assertRaises(OverflowError, NodeGridValues, timepoints, gridnodes=[-1, 4]) + # bad timepoint + self.assertRaises(ValueError, NodeGridValues, [], gridnodes=nonfixed_ids) def test_clone(self): - num_nodes = 10 - grid_size = 2 - ids = [3, 4] - orig = NodeGridValues(num_nodes, np.array(ids), np.array(range(grid_size))) + timepoints = [0, 1] + nonfixed_ids = [3, 4] + orig = NodeGridValues(timepoints, gridnodes=nonfixed_ids) orig[3] = np.array([1, 2]) orig[4] = np.array([4, 3]) - orig[0] = 1.5 - orig[9] = 2.5 # test with np.zeros - clone = NodeGridValues.clone_with_new_data(orig, 0) + clone = orig.clone_grid_with_new_data(0) self.assertEquals(clone.grid_data.shape, orig.grid_data.shape) - self.assertEquals(clone.fixed_data.shape, orig.fixed_data.shape) self.assertTrue(np.all(clone.grid_data == 0)) - self.assertTrue(np.all(clone.fixed_data == 0)) # test with something else - clone = NodeGridValues.clone_with_new_data(orig, 5) + clone = orig.clone_grid_with_new_data(5) self.assertEquals(clone.grid_data.shape, orig.grid_data.shape) - self.assertEquals(clone.fixed_data.shape, orig.fixed_data.shape) self.assertTrue(np.all(clone.grid_data == 5)) - self.assertTrue(np.all(clone.fixed_data == 5)) - # test with different - scalars = np.arange(num_nodes - len(ids)) - clone = NodeGridValues.clone_with_new_data(orig, 0, scalars) - self.assertEquals(clone.grid_data.shape, orig.grid_data.shape) - self.assertEquals(clone.fixed_data.shape, orig.fixed_data.shape) - self.assertTrue(np.all(clone.grid_data == 0)) - self.assertTrue(np.all(clone.fixed_data == scalars)) - clone = NodeGridValues.clone_with_new_data( - orig, np.array([[1, 2], [4, 3]])) - for i in range(num_nodes): - if i in ids: - self.assertTrue(np.all(clone[i] == orig[i])) - else: - self.assertTrue(np.isnan(clone[i])) - clone = NodeGridValues.clone_with_new_data( - orig, np.array([[1, 2], [4, 3]]), 0) - for i in range(num_nodes): - if i in ids: + clone = orig.clone_grid_with_new_data(np.array([[1, 2], [4, 3]])) + for i in range(np.max(nonfixed_ids)+1): + if i in nonfixed_ids: self.assertTrue(np.all(clone[i] == orig[i])) else: - self.assertEquals(clone[i], 0) + self.assertRaises(IndexError, clone.__getitem__, i) def test_bad_clone(self): - num_nodes = 10 - ids = [3, 4] - orig = NodeGridValues(num_nodes, np.array(ids), np.array([0, 1.2])) - self.assertRaises( - ValueError, - NodeGridValues.clone_with_new_data, - orig, np.array([[1, 2, 3], [4, 5, 6]])) + ids = np.array([3, 4]) + timepoints = np.array([0, 1.2]) + orig = NodeGridValues(timepoints, gridnodes=ids) self.assertRaises( - ValueError, - NodeGridValues.clone_with_new_data, - orig, 0, np.array([[1, 2], [4, 5]])) + ValueError, orig.clone_grid_with_new_data, np.array([[1, 2, 3], [4, 5, 6]])) class TestAlgorithmClass(unittest.TestCase): @@ -780,7 +812,7 @@ def test_nonmatching_prior_vs_lik_timepoints(self): def test_nonmatching_prior_vs_lik_fixednodes(self): ts1 = utility_functions.single_tree_ts_n3() - ts2 = utility_functions.single_tree_ts_n3_dangling() + ts2 = utility_functions.single_tree_ts_n2_dangling() timepoints = np.array([0, 1.2, 2]) prior = tsdate.build_prior_grid(ts1, timepoints) lls = Likelihoods(ts2, prior.timepoints) @@ -892,7 +924,7 @@ def test_two_tree_mutation_ts(self): self.assertTrue(np.allclose(algo.inside[5], np.array([0, 7.06320034e-11, 1]))) def test_dangling_fails(self): - ts = utility_functions.single_tree_ts_n3_dangling() + ts = utility_functions.single_tree_ts_n2_dangling() print(ts.draw_text()) print("Samples:", ts.samples()) prior = tsdate.build_prior_grid(ts, timepoints=np.array([0, 1.2, 2])) diff --git a/tests/test_inference.py b/tests/test_inference.py index 887d9b9f..bede5f73 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -39,7 +39,7 @@ class TestPrebuilt(unittest.TestCase): Tests for tsdate on prebuilt tree sequences """ def test_dangling_failure(self): - ts = utility_functions.single_tree_ts_n3_dangling() + ts = utility_functions.single_tree_ts_n2_dangling() self.assertRaisesRegexp(ValueError, "dangling", tsdate.date, ts, Ne=1) def test_unary_warning(self): @@ -48,7 +48,7 @@ def test_unary_warning(self): self.assertEqual(len(log.output), 1) self.assertIn("unary nodes", log.output[0]) - def test_fails_with_recombination(self): + def test_fails_with_recombination_clock(self): ts = utility_functions.two_tree_mutation_ts() for probability_space in (LOG, LIN): self.assertRaises( @@ -58,6 +58,12 @@ def test_fails_with_recombination(self): NotImplementedError, tsdate.date, ts, Ne=1, recombination_rate=1, probability_space=probability_space, mutation_rate=1) + def test_non_contemporaneous(self): + ts = utility_functions.two_tree_ts_n3_non_contemporaneous() + theta = 2 + ts = msprime.mutate(ts, rate=theta) + tsdate.date(ts, Ne=1, mutation_rate=theta, probability_space=LIN) + # def test_simple_ts_n2(self): # ts = utility_functions.single_tree_ts_n2() # dated_ts = tsdate.date(ts, Ne=10000) @@ -209,7 +215,8 @@ def test_non_contemporaneous(self): msprime.Sample(population=0, time=0), msprime.Sample(population=0, time=1.0) ] - ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2) + ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=123) + print(ts.draw_text()) self.assertRaises(NotImplementedError, tsdate.date, ts, 1, 2) @unittest.skip("YAN to fix") diff --git a/tests/utility_functions.py b/tests/utility_functions.py index 524b2f62..4fad9626 100644 --- a/tests/utility_functions.py +++ b/tests/utility_functions.py @@ -44,10 +44,10 @@ def add_grand_mrca(ts): def single_tree_ts_n2(): r""" - Simple case where we have n = 2 and one tree. + Simple case where we have n = 2 and one tree. [] marks a sample 2 / \ - 0 1 + [0] [1] """ nodes = io.StringIO("""\ id is_sample time @@ -69,7 +69,7 @@ def single_tree_ts_n3(): / \ 3 \ / \ \ - 0 1 2 + [0] [1] [2] """ nodes = io.StringIO("""\ id is_sample time @@ -96,7 +96,7 @@ def single_tree_ts_n4(): / \ \ 4 \ \ / \ \ \ - 0 1 2 3 + [0] [1] [2] [3] """ nodes = io.StringIO("""\ id is_sample time @@ -124,7 +124,7 @@ def single_tree_ts_mutation_n3(): / \ 3 x / \ \ - 0 1 2 + [0] [1] [2] """ nodes = io.StringIO("""\ id is_sample time @@ -158,7 +158,7 @@ def single_tree_all_samples_one_mutation_n3(): / \ 3 x / \ \ - 0 1 2 + [0] [1] [2] """ nodes = io.StringIO("""\ id is_sample time @@ -188,16 +188,16 @@ def single_tree_all_samples_one_mutation_n3(): def gils_example_tree(): r""" Simple case where we have n = 3 and one tree. - Number of mutations on each branch are in parentheses. + Mutations marked on each branch by *. 4 / \ - (0) \ - / (4) - 3 \ - / \ \ - (2) (1) \ - / \ \ - 0 1 2 + / \ + / * + 3 * + / \ * + * * * + * \ \ + [0] [1] [2] """ nodes = io.StringIO("""\ id is_sample time @@ -241,7 +241,8 @@ def polytomy_tree_ts(): Simple case where we have n = 3 and a polytomy. 3 /|\ - 0 1 2 + / | \ + [0][1][2] """ nodes = io.StringIO("""\ id is_sample time @@ -264,7 +265,7 @@ def single_tree_ts_internal_n3(): / \ 3 \ / \ \ - 0 1 2 + [0] [1] [2] """ nodes = io.StringIO("""\ id is_sample time @@ -291,7 +292,7 @@ def two_tree_ts(): / \ . | |\ 3 \ . | | \ / \ \ . | | \ - 0 1 2 . 0 1 2 + [0] [1] [2] . [0] [1] [2] """ nodes = io.StringIO("""\ id is_sample time @@ -313,6 +314,25 @@ def two_tree_ts(): return tskit.load_text(nodes=nodes, edges=edges, strict=False) +def two_tree_ts_n3_non_contemporaneous(): + r""" + Simple case where we have n = 3 and two trees with node 2 ancient. + . 5 + . / \ + 4 . | 4 + / \ . | |\ + 3 [2] . | |[2] + / \ . | | + [0] [1] . [0] [1] + """ + ts = two_tree_ts() + tables = ts.dump_tables() + time = tables.nodes.time + time[2] = time[3] + tables.nodes.time = time + return tables.tree_sequence() + + def single_tree_ts_with_unary(): r""" Simple case where we have n = 3 and some unary nodes. @@ -324,7 +344,7 @@ def single_tree_ts_with_unary(): | | 3 | / \ | - 0 1 2 + [0] [1] [2] """ nodes = io.StringIO("""\ id is_sample time @@ -360,7 +380,7 @@ def two_tree_mutation_ts(): / | . | | | 3 | . | | | / \ | . | | | - 0 1 2 . 0 1 2 + [0] [1] [2] . [0] [1] [2] """ nodes = io.StringIO("""\ id is_sample time @@ -407,8 +427,7 @@ def two_tree_two_mrcas(): 4 5 | 4 5 / \ / \ | / \ / \ / \ / \ | / \ / \ - | | | | | | | | | - 0 1 2 3 | 0 1 2 3 + [0] [1] [2] [3] | [0] [1] [2] [3] """ nodes = io.StringIO("""\ id is_sample time @@ -438,17 +457,17 @@ def loopy_tree(): Simple case where we have n = 3, 2 trees, three mutations. . 7 . / \ - . / | - . / | - 6 . / 6 - / \ . / / \ - / 5 . / / | - / / \ . / / | - | / \ . | | | - | | \ . | | | - | 4 | . | 4 | - | / \ | . | / \ | - 0 1 2 3 . 0 1 2 3 + . / | + . / | + 6 . / 6 + / \ . / / \ + / 5 . / / | + / / \ . / / | + / | \ . | | | + / | \ . | | | + | 4 | . | 4 | + | / \ | . | / \ | + [0] [1] [2] [3] . [0] [1] [2] [3] """ nodes = io.StringIO("""\ id is_sample time @@ -481,7 +500,7 @@ def single_tree_ts_n3_sample_as_parent(): / \ 3 \ / \ \ - 0 1 2 + [0] [1] [2] """ nodes = io.StringIO("""\ id is_sample time @@ -499,14 +518,58 @@ def single_tree_ts_n3_sample_as_parent(): return tskit.load_text(nodes=nodes, edges=edges, strict=False) -def single_tree_ts_n3_dangling(): - # Mark node 0 as a non-sample node, which should make it dangling - ts = single_tree_ts_n3() - tables = ts.dump_tables() - flags = tables.nodes.flags - flags[0] = flags[0] & (~tskit.NODE_IS_SAMPLE) - tables.nodes.flags = flags - return tables.tree_sequence() +def single_tree_ts_n2_dangling(): + r""" + Simple case where we have n = 2 and one tree. Node 0 is dangling. + 4 + / \ + 3 \ + / \ \ + 0 [1] [2] + """ + nodes = io.StringIO("""\ + id is_sample time + 0 0 0 + 1 1 0 + 2 1 0 + 3 0 1 + 4 0 2 + """) + edges = io.StringIO("""\ + left right parent child + 0 1 3 0,1 + 0 1 4 2,3 + """) + return tskit.load_text(nodes=nodes, edges=edges, strict=False) + + +def two_tree_ts_n2_part_dangling(): + r""" + Simple case where we have n = 2 and two trees. Node 0 is dangling in the first tree. + 4 4 + / \ / \ + 3 \ 3 \ + / \ \ \ \ + 0 \ \ 0 \ + \ \ \ \ + [1] [2] [1] [2] + """ + nodes = io.StringIO("""\ + id is_sample time + 0 0 0.5 + 1 1 0 + 2 1 0 + 3 1 1 + 4 0 2 + """) + edges = io.StringIO("""\ + left right parent child + 0 1 3 0 + 0 0.5 3 1 + 0.5 1 0 1 + 0 1 4 2,3 + """) + return tskit.load_text(nodes=nodes, edges=edges, strict=False) def truncate_ts_samples(ts, average_span, random_seed, min_span=5): diff --git a/tsdate/date.py b/tsdate/date.py index 21bbd08f..13ede8fc 100644 --- a/tsdate/date.py +++ b/tsdate/date.py @@ -69,8 +69,11 @@ def tree_is_isolated(tree, node): return tree_num_children(tree, node) == 0 and tree.parent(node) == tskit.NULL -def tree_iterator_len(it): - return it.tree_sequence.num_trees +def to_np_float_array(float_array): + if isinstance(float_array, np.ndarray): + return float_array.astype(FLOAT_DTYPE, copy=False) + else: + return np.array(float_array, dtype=FLOAT_DTYPE) def get_single_root(tree): @@ -928,42 +931,39 @@ class NodeGridValues: :ivar num_nodes: The number of nodes that will be stored in this object :vartype num_nodes: int - :ivar nonfixed_nodes: a (possibly empty) numpy array of unique positive node ids each + :ivar gridnodes: a (possibly empty) numpy array of unique positive node ids each of which must be less than num_nodes. Each will have an array of grid_size associated with it. All others (up to num_nodes) will be associated with a single scalar value instead. - :vartype nonfixed_nodes: numpy.ndarray + :vartype gridnodes: numpy.ndarray :ivar timepoints: Array of time points :vartype timepoints: numpy.ndarray :ivar fill_value: What should we fill the data arrays with to start with :vartype fill_value: numpy.scalar """ - def __init__(self, num_nodes, nonfixed_nodes, timepoints, - fill_value=np.nan, dtype=FLOAT_DTYPE): + def __init__(self, timepoints, *, gridnodes, fill_value=np.nan): """ :param numpy.ndarray grid: The input numpy.ndarray. """ - if nonfixed_nodes.ndim != 1: - raise ValueError("nonfixed_nodes must be a 1D numpy array") - if np.any((nonfixed_nodes < 0) | (nonfixed_nodes >= num_nodes)): - raise ValueError( - "All non fixed node ids must be between zero and the total node number") - grid_size = len(timepoints) if type(timepoints) is np.ndarray else timepoints - self.timepoints = timepoints + self.timepoints = to_np_float_array(timepoints) + self.gridnodes = tskit.safe_np_int_cast(gridnodes, dtype=np.uint64) + if self.gridnodes.ndim != 1: + raise ValueError("Provided gridnodes list must be a 1D numpy array") + if len(np.unique(self.gridnodes)) != len(gridnodes): + raise ValueError("Provided gridnodes list must be unique") + if len(self.timepoints) < 1: + raise ValueError("Must have at least one timepoint") + # Make timepoints immutable so no risk of overwritting them with copy self.timepoints.setflags(write=False) - self.num_nodes = num_nodes - self.nonfixed_nodes = nonfixed_nodes - self.num_nonfixed = len(nonfixed_nodes) - self.grid_data = np.full((self.num_nonfixed, grid_size), fill_value, dtype=dtype) - self.fixed_data = np.full(num_nodes - self.num_nonfixed, fill_value, dtype=dtype) - self.row_lookup = np.empty(num_nodes, dtype=np.int64) - # non-fixed nodes get a positive value, indicating lookup in the grid_data array - self.row_lookup[nonfixed_nodes] = np.arange(self.num_nonfixed) - # fixed nodes get a negative value from -1, indicating lookup in the scalar array - self.row_lookup[np.logical_not(np.isin(np.arange(num_nodes), nonfixed_nodes))] =\ - -np.arange(num_nodes - self.num_nonfixed) - 1 + self.num_gridnodes = len(self.gridnodes) + # This stores likelihoods + self.grid_data = np.full( + (self.num_gridnodes, len(self.timepoints)), fill_value, dtype=FLOAT_DTYPE) + # Map node numbers onto grid_data rows + self.grid_lookup = -np.ones(np.max(self.gridnodes) + np.uint(1), dtype=np.int64) + self.grid_lookup[self.gridnodes] = np.arange(self.num_gridnodes) self.probability_space = LIN def force_probability_space(self, probability_space): @@ -971,28 +971,22 @@ def force_probability_space(self, probability_space): probability_space can be "logarithmic" or "linear": this function will force the current probability space to the desired type """ - descr = self.probability_space, " probabilities into", probability_space, "space" + prob_spaces = [LOG, LIN] + if probability_space not in prob_spaces: + raise ValueError("probability_space must be one of {}".format(prob_spaces)) if probability_space == LIN: if self.probability_space == LIN: pass elif self.probability_space == LOG: self.grid_data = np.exp(self.grid_data) - self.fixed_data = np.exp(self.fixed_data) self.probability_space = LIN - else: - logging.warning("Cannot force", *descr) elif probability_space == LOG: if self.probability_space == LOG: pass elif self.probability_space == LIN: with np.errstate(divide='ignore'): self.grid_data = np.log(self.grid_data) - self.fixed_data = np.log(self.fixed_data) self.probability_space = LOG - else: - logging.warning("Cannot force", *descr) - else: - logging.warning("Cannot force", *descr) def normalize(self): """ @@ -1004,65 +998,44 @@ def normalize(self): elif self.probability_space == LOG: self.grid_data = self.grid_data - rowmax[:, np.newaxis] else: - raise RuntimeError("Probability space is not", LIN, "or", LOG) + raise ValueError("Probability space is not", LIN, "or", LOG) def __getitem__(self, node_id): - index = self.row_lookup[node_id] - if index < 0: - return self.fixed_data[1 + index] + row = self.grid_lookup[node_id] + if row < 0: + raise IndexError("Bad index") else: - return self.grid_data[index, :] + return self.grid_data[row, :] def __setitem__(self, node_id, value): - index = self.row_lookup[node_id] - if index < 0: - self.fixed_data[1 + index] = value + row = self.grid_lookup[node_id] + if row < 0: + raise IndexError("Bad index") else: - self.grid_data[index, :] = value + self.grid_data[row, :] = value - def clone_with_new_data( - self, grid_data=np.nan, fixed_data=None, probability_space=None): + def clone_grid_with_new_data(self, data, probability_space=None): """ - Take the row indices etc from an existing NodeGridValues object and make a new - similar one but with different data. If grid_data is a single number, fill the + Take the gridnodes and lookup etc from an existing NodeGridValues object and make + a new similar one but with different data. If data is a single number, fill the entire data array with that, otherwise assume the data is a numpy array of the - correct size to fill the gridded data. If grid_data is None, fill with NaN - - If fixed_data is None and grid_data is a single number, use the same value as - grid_data for the fixed data values. If fixed_data is None and grid_data is an - array, set the fixed data to np.nan - """ - def fill_fixed(orig, fixed_data): - if type(fixed_data) is np.ndarray: - if orig.fixed_data.shape != fixed_data.shape: - raise ValueError( - "The fixed data array must be the same shape as the original") - return fixed_data - else: - return np.full( - orig.fixed_data.shape, fixed_data, dtype=orig.fixed_data.dtype) + correct size to fill the gridded data. If data is None, fill with NaN + """ new_obj = NodeGridValues.__new__(NodeGridValues) - new_obj.num_nodes = self.num_nodes - new_obj.nonfixed_nodes = self.nonfixed_nodes - new_obj.num_nonfixed = self.num_nonfixed - new_obj.row_lookup = self.row_lookup new_obj.timepoints = self.timepoints - if type(grid_data) is np.ndarray: - if self.grid_data.shape != grid_data.shape: + new_obj.gridnodes = self.gridnodes + new_obj.num_gridnodes = self.num_gridnodes + new_obj.grid_lookup = self.grid_lookup + if type(data) is np.ndarray: + if self.grid_data.shape != data.shape: raise ValueError( "The grid data array must be the same shape as the original") - new_obj.grid_data = grid_data - new_obj.fixed_data = fill_fixed( - self, np.nan if fixed_data is None else fixed_data) + new_obj.grid_data = data else: - if grid_data == 0: # Fast allocation - new_obj.grid_data = np.zeros( - self.grid_data.shape, dtype=self.grid_data.dtype) + if data == 0: # Fast allocation + new_obj.grid_data = np.zeros_like(self.grid_data) else: - new_obj.grid_data = np.full( - self.grid_data.shape, grid_data, dtype=self.grid_data.dtype) - new_obj.fixed_data = fill_fixed( - self, grid_data if fixed_data is None else fixed_data) + new_obj.grid_data = np.full_like(self.grid_data, data) if probability_space is None: new_obj.probability_space = self.probability_space else: @@ -1093,10 +1066,10 @@ def fill_prior(node_parameters, timepoints, ts, *, prior_distr, progress=False): datable_nodes = np.ones(ts.num_nodes, dtype=bool) datable_nodes[ts.samples()] = False datable_nodes = np.where(datable_nodes)[0] - prior_times = NodeGridValues( - ts.num_nodes, - datable_nodes[np.argsort(ts.tables.nodes.time[datable_nodes])].astype(np.int32), - timepoints) + # Sort by time + datable_nodes = datable_nodes[ + np.argsort(ts.tables.nodes.time[datable_nodes])].astype(np.int32) + prior_times = NodeGridValues(timepoints, gridnodes=datable_nodes) # TO DO - this can probably be done in an single numpy step rather than a for loop for node in tqdm(datable_nodes, desc="Assign Prior to Each Node", @@ -1286,10 +1259,12 @@ def get_mut_lik_fixed_node(self, edge): mutations_on_edge = self.mut_edges[edge.id] child_time = self.ts.node(edge.child).time - assert child_time == 0 - # Temporary hack - we should really take a more precise likelihood - return self._lik(mutations_on_edge, edge_span(edge), self.timediff, self.theta, - normalize=self.normalize) + timediff = self.timediff - child_time + mask = timediff > 0 + lik = np.full(len(timediff), self.null_constant, dtype=FLOAT_DTYPE) + lik[mask] = self._lik(mutations_on_edge, edge_span(edge), timediff[mask], + self.theta, normalize=self.normalize) + return lik def get_mut_lik_lower_tri(self, edge): """ @@ -1408,7 +1383,9 @@ def get_fixed(self, arr, edge): liks *= self.get_mut_lik_fixed_node(edge) return arr * liks - def scale_geometric(self, fraction, value): + def scale_geometric(self, fraction, value=None): + if value is None: + value = self.identity_constant return value ** fraction @@ -1523,7 +1500,9 @@ def get_fixed(self, arr, edge): log_liks += self.get_mut_lik_fixed_node(edge) return arr + log_liks - def scale_geometric(self, fraction, value): + def scale_geometric(self, fraction, value=None): + if value is None: + value = self.identity_constant return fraction * value @@ -1554,8 +1533,8 @@ class InOutAlgorithms: Contains the inside and outside algorithms """ def __init__(self, prior, lik, *, progress=False): - if (lik.fixednodes.intersection(prior.nonfixed_nodes) or - len(lik.fixednodes) + len(prior.nonfixed_nodes) != lik.ts.num_nodes): + if (lik.fixednodes.intersection(prior.gridnodes) or + len(lik.fixednodes) + len(prior.gridnodes) != lik.ts.num_nodes): raise ValueError( "The prior and likelihood objects disagree on which nodes are fixed") if not np.allclose(lik.timepoints, prior.timepoints): @@ -1563,7 +1542,7 @@ def __init__(self, prior, lik, *, progress=False): "The prior and likelihood objects disagree on the timepoints used") self.prior = prior - self.nonfixed_nodes = prior.nonfixed_nodes + self.gridnodes = prior.gridnodes self.lik = lik self.ts = lik.ts @@ -1635,8 +1614,7 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None): """ if progress is None: progress = self.progress - inside = self.prior.clone_with_new_data( # store inside matrix values - grid_data=np.nan, fixed_data=self.lik.identity_constant) + inside = self.prior.clone_grid_with_new_data(data=np.nan) if cache_inside: g_i = np.full( (self.ts.num_edges, self.lik.grid_size), self.lik.identity_constant) @@ -1644,7 +1622,7 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None): # Iterate through the nodes via groupby on parent node for parent, edges in tqdm( self.edges_by_parent_asc(), desc="Inside", - total=inside.num_nonfixed, disable=not progress): + total=inside.num_gridnodes, disable=not progress): """ for each node, find the conditional prob of age at every time in time grid @@ -1658,15 +1636,15 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None): if edge.child in self.fixednodes: # NB: geometric scaling works exactly when all nodes fixed in graph # but is an approximation when times are unknown. - daughter_val = self.lik.scale_geometric(spanfrac, inside[edge.child]) + daughter_val = self.lik.scale_geometric(spanfrac) edge_lik = self.lik.get_fixed(daughter_val, edge) else: inside_values = inside[edge.child] if np.ndim(inside_values) == 0 or np.all(np.isnan(inside_values)): # Child appears fixed, or we have not visited it. Either our # edge order is wrong (bug) or we have hit a dangling node - raise ValueError("The input tree sequence includes " - "dangling nodes: please simplify it") + raise ValueError("Node {} appears to be dangling: please " + "simplify the tree sequence".format(edge.child)) daughter_val = self.lik.scale_geometric( spanfrac, self.lik.make_lower_tri(inside[edge.child])) edge_lik = self.lik.get_inside(daughter_val, edge) @@ -1698,8 +1676,7 @@ def outside_pass(self, *, if not hasattr(self, "inside"): raise RuntimeError("You have not yet run the inside algorithm") - outside = self.inside.clone_with_new_data( - grid_data=0, probability_space=LIN) + outside = self.inside.clone_grid_with_new_data(data=0, probability_space=LIN) for root, span_when_root in self.root_spans.items(): outside[root] = span_when_root / self.spans[root] outside.force_probability_space(self.inside.probability_space) @@ -1740,9 +1717,9 @@ def outside_pass(self, *, outside[child] = self.lik.reduce(val, self.norm[child]) if normalize: outside[child] = self.lik.reduce(val, np.max(val)) - posterior = outside.clone_with_new_data( - grid_data=self.lik.combine(self.inside.grid_data, outside.grid_data), - fixed_data=np.nan) # We should never use the posterior for a fixed node + posterior = outside.clone_grid_with_new_data( + data=self.lik.combine(self.inside.grid_data, outside.grid_data), + ) # We should never use the posterior for a fixed node posterior.normalize() posterior.force_probability_space(probability_space_returned) self.outside = outside @@ -1818,7 +1795,7 @@ def posterior_mean_var(ts, timepoints, posterior, *, fixed_node_set=None): mn_post[fixed_nodes] = ts.tables.nodes.time[fixed_nodes] vr_post[fixed_nodes] = 0 - for row, node_id in zip(posterior.grid_data, posterior.nonfixed_nodes): + for row, node_id in zip(posterior.grid_data, posterior.gridnodes): mn_post[node_id] = np.sum(row * timepoints) / np.sum(row) vr_post[node_id] = (np.sum(row * timepoints ** 2) / np.sum(row) - mn_post[node_id] ** 2) @@ -1859,7 +1836,8 @@ def build_prior_grid(tree_sequence, timepoints=20, *, approximate_prior=False, time slices at which to evaluate node age. :param TreeSequence tree_sequence: The input :class:`tskit.TreeSequence`, treated as - undated + undated. Currently, only the samples at time 0 are used to create the conditional + coalescent prior. :param int_or_array_like timepoints: The number of quantiles used to create the time slices, or manually-specified time slices as a numpy array :param bool approximate_prior: Whether to use a precalculated approximate prior or @@ -1989,11 +1967,6 @@ def get_dates( :return: tuple(mn_post, posterior, timepoints, eps, nodes_to_date) """ - # Stuff yet to be implemented. These can be deleted once fixed - for sample in tree_sequence.samples(): - if tree_sequence.node(sample).time != 0: - raise NotImplementedError( - "Samples must all be at time 0") fixed_nodes = set(tree_sequence.samples()) # Default to not creating approximate prior unless ts has > 1000 samples @@ -2042,4 +2015,4 @@ def get_dates( raise ValueError( "estimation method must be either 'inside_outside' or 'maximization'") - return mn_post, posterior, prior.timepoints, eps, prior.nonfixed_nodes + return mn_post, posterior, prior.timepoints, eps, prior.gridnodes