Skip to content

Commit

Permalink
Remove fixed values (can be got from the TS)
Browse files Browse the repository at this point in the history
  • Loading branch information
hyanwong committed Feb 29, 2020
1 parent c6ae1a6 commit 606d4ae
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 70 deletions.
52 changes: 15 additions & 37 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from tsdate.date import (SpansBySamples, PriorParams, LIN, LOG,
ConditionalCoalescentTimes, fill_prior, Likelihoods,
LogLikelihoods, LogLikelihoodsStreaming, InOutAlgorithms,
Prior, gamma_approx, constrain_ages_topo) # NOQA
NodeGridValues, gamma_approx, constrain_ages_topo) # NOQA

from tests import utility_functions

Expand Down Expand Up @@ -689,32 +689,24 @@ def test_logsumexp_streaming(self):
np.log(ll_sum)))


class TestPriorClass(unittest.TestCase):
class TestNodeGridValuesClass(unittest.TestCase):
def test_init(self):
nodetimes = np.ones(5)
nonfixed_ids = np.array([3, 2])
timepoints = np.array(range(10))
store = Prior(
timepoints, nodetimes=nodetimes, gridnodes=nonfixed_ids, fill_value=6)
store = NodeGridValues(timepoints, gridnodes=nonfixed_ids, fill_value=6)
self.assertEquals(store.grid_data.shape, (len(nonfixed_ids), len(timepoints)))
self.assertEquals(len(store.fixed_times), (len(nodetimes)-len(nonfixed_ids)))
self.assertTrue(np.all(store.grid_data == 6))
self.assertTrue(np.all(store.fixed_times == 1))
for i in range(len(nodetimes)):
for i in range(np.max(nonfixed_ids)+1):
if i in nonfixed_ids:
self.assertTrue(np.all(store[i] == 6))
self.assertRaises(IndexError, store.fixed_time, i)
else:
self.assertEqual(store.fixed_time(i), 1)
with self.assertRaises(IndexError):
_ = store[i]

def test_probability_spaces(self):
nodetimes = np.ones(5)
nonfixed_ids = np.array([3, 4])
timepoints = np.array(range(10))
store = Prior(
timepoints, nodetimes=nodetimes, gridnodes=nonfixed_ids, fill_value=0.5)
store = NodeGridValues(timepoints, gridnodes=nonfixed_ids, fill_value=0.5)
self.assertTrue(np.all(store.grid_data == 0.5))
store.force_probability_space(LIN)
self.assertTrue(np.all(store.grid_data == 0.5))
Expand All @@ -727,13 +719,12 @@ def test_probability_spaces(self):
self.assertRaises(ValueError, store.force_probability_space, "foobar")

def test_set_and_get(self):
nodetimes = np.ones(5)
timepoints = [0, 1.1]
fill = {}
for nonfixed_ids in ([3, 4], [0]):
np.random.seed(1)
store = Prior(timepoints, nodetimes=nodetimes, gridnodes=nonfixed_ids)
for i in range(len(nodetimes)):
store = NodeGridValues(timepoints, gridnodes=nonfixed_ids)
for i in range(5):
fill[i] = np.random.random(len(store.timepoints))
if i in nonfixed_ids:
store[i] = fill[i]
Expand All @@ -745,33 +736,21 @@ def test_set_and_get(self):

def test_bad_init(self):
timepoints = [0, 1.2, 2]
nodetimes = np.ones(5)
nonfixed_ids = [4, 0]
Prior(timepoints, nodetimes=nodetimes, gridnodes=nonfixed_ids)
# ids > nodetimes
self.assertRaises(
ValueError, Prior, timepoints, nodetimes=nodetimes, gridnodes=[4, 5])
NodeGridValues(timepoints, gridnodes=nonfixed_ids)
# duplicate ids
self.assertRaises(
ValueError, Prior, timepoints, nodetimes=nodetimes, gridnodes=[4, 4, 0])
self.assertRaises(ValueError, NodeGridValues, timepoints, gridnodes=[4, 4, 0])
# bad ids
self.assertRaises(
ValueError, Prior, timepoints, nodetimes=nodetimes,
gridnodes=np.array([[1, 4], [2, 0]]))
self.assertRaises(
OverflowError, Prior, timepoints, nodetimes=nodetimes, gridnodes=[-1, 4])
ValueError, NodeGridValues, timepoints, gridnodes=np.array([[1, 4], [2, 0]]))
self.assertRaises(OverflowError, NodeGridValues, timepoints, gridnodes=[-1, 4])
# bad timepoint
self.assertRaises(
ValueError, Prior, [], nodetimes=nodetimes, gridnodes=nonfixed_ids)
# bad nodetimes
self.assertRaises(
ValueError, Prior, timepoints, nodetimes=[], gridnodes=nonfixed_ids)
self.assertRaises(ValueError, NodeGridValues, [], gridnodes=nonfixed_ids)

def test_clone(self):
timepoints = [0, 1]
nodetimes = np.ones(5)
nonfixed_ids = [3, 4]
orig = Prior(timepoints, nodetimes=nodetimes, gridnodes=nonfixed_ids)
orig = NodeGridValues(timepoints, gridnodes=nonfixed_ids)
orig[3] = np.array([1, 2])
orig[4] = np.array([4, 3])
# test with np.zeros
Expand All @@ -784,17 +763,16 @@ def test_clone(self):
self.assertTrue(np.all(clone.grid_data == 5))

clone = orig.clone_grid_with_new_data(np.array([[1, 2], [4, 3]]))
for i in range(len(nodetimes)):
for i in range(np.max(nonfixed_ids)+1):
if i in nonfixed_ids:
self.assertTrue(np.all(clone[i] == orig[i]))
else:
self.assertRaises(IndexError, clone.__getitem__, i)

def test_bad_clone(self):
nodetimes = np.zeros(10)
ids = np.array([3, 4])
timepoints = np.array([0, 1.2])
orig = Prior(timepoints, nodetimes=nodetimes, gridnodes=ids)
orig = NodeGridValues(timepoints, gridnodes=ids)
self.assertRaises(
ValueError, orig.clone_grid_with_new_data, np.array([[1, 2, 3], [4, 5, 6]]))

Expand Down
34 changes: 1 addition & 33 deletions tsdate/date.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,37 +1028,6 @@ def clone_grid_with_new_data(self, data, probability_space=None):
return new_obj


class Prior(NodeGridValues):
"""
The same as a NodeGridValues object, but also containing times (not probabilities)
for fixed nodes (i.e. those not in the grid)
`nodetimes` is a list of provisional (fixed) times for all the nodes
`gridnodes` is a numpy array listing the nodes identified as non-fixed: the nodes
in this list will have their `nodetimes` ignored
"""
def __init__(self, timepoints, *, nodetimes, gridnodes, fill_value=np.nan):
gridnodes = tskit.safe_np_int_cast(gridnodes, dtype=np.uint64)
if np.any(gridnodes >= len(nodetimes)):
raise ValueError(
"All non fixed node ids must be less than the total node number")
super().__init__(timepoints, gridnodes=gridnodes, fill_value=fill_value)

# make a list of the fixed indices
fixed = np.ones_like(nodetimes, dtype=bool)
fixed[gridnodes] = False
self.fixed_times = nodetimes[fixed]
# Map node numbers onto grid_data rows
longest_needed = np.max(np.nonzero(fixed)) + 1
self.fixed_lookup = -np.ones(longest_needed, dtype=np.int64)
self.fixed_lookup[fixed[:longest_needed]] = np.arange(len(self.fixed_times))

def fixed_time(self, node_id):
row = self.fixed_lookup[node_id]
if row < 0:
raise IndexError("Bad index")
return self.fixed_times[row]


def fill_prior(distr_parameters, timepoints, ts, nodes_to_date, prior_distr,
progress=False):
"""
Expand All @@ -1069,9 +1038,8 @@ def fill_prior(distr_parameters, timepoints, ts, nodes_to_date, prior_distr,
TODO - what if there is an internal fixed node? Should we truncate
"""
# Sort nodes-to-date by time, as that's the order given when iterating over edges
prior_times = Prior(
prior_times = NodeGridValues(
timepoints,
nodetimes=ts.tables.nodes.time,
gridnodes=nodes_to_date[
np.argsort(ts.tables.nodes.time[nodes_to_date])].astype(np.int32))
if prior_distr == 'lognorm':
Expand Down

0 comments on commit 606d4ae

Please sign in to comment.