Skip to content

Commit

Permalink
Use new algo for truncating priors
Browse files Browse the repository at this point in the history
  • Loading branch information
hyanwong committed Aug 20, 2022
1 parent da58644 commit de1c751
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 29 deletions.
3 changes: 3 additions & 0 deletions tsdate/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ def normalize(self):
else:
raise RuntimeError("Probability space is not", LIN, "or", LOG)

def is_fixed(self, node_id):
return self.row_lookup[node_id] < 0

def __getitem__(self, node_id):
index = self.row_lookup[node_id]
if index < 0:
Expand Down
57 changes: 28 additions & 29 deletions tsdate/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,14 +1027,10 @@ def shape_scale_from_mean_var(mean, var):

def _truncate_priors(ts, priors, progress=False):
"""
Truncate priors for the nodes listed in truncate_nodes (or all nonfixed nodes
if truncate_nodes is None) so they conform to the age of fixed nodes in the tree
sequence
Truncate priors for all nonfixed nodes
so they conform to the age of fixed nodes in the tree sequence
"""
tables = ts.tables
truncate_nodes = priors.nonfixed_node_ids()
# ensure truncate_nodes is ordered by node time
truncate_nodes = truncate_nodes[np.argsort(tables.nodes.time[truncate_nodes])]

fixed_nodes = priors.fixed_node_ids()
fixed_times = tables.nodes.time[fixed_nodes]
Expand All @@ -1050,29 +1046,32 @@ def _truncate_priors(ts, priors, progress=False):
constrained_min_times = np.zeros_like(tables.nodes.time)
# Set the min times of fixed nodes to those in the tree sequence
constrained_min_times[fixed_nodes] = fixed_times
constrained_max_times = np.full_like(constrained_min_times, np.inf)

parents = tables.edges.parent
nd_children = tables.edges.child[np.argsort(parents)]
parents = sorted(parents)
parents_unique = np.unique(parents, return_index=True)
parent_indices = parents_unique[1][np.isin(parents_unique[0], truncate_nodes)]
for index, nd in tqdm(
enumerate(truncate_nodes), desc="Constrain Ages", disable=not progress

# Traverse through the ARG, ensuring children come before parents.
# This can be done by iterating over groups of edges with the same parent
new_parent_edge_idx = np.concatenate(
(
[0],
np.where(np.diff(tables.edges.parent) != 0)[0] + 1,
[tables.edges.num_rows],
)
)
for edges_start, edges_end in zip(
new_parent_edge_idx[:-1], new_parent_edge_idx[1:]
):
if index + 1 != len(truncate_nodes):
children_index = np.arange(parent_indices[index], parent_indices[index + 1])
else:
children_index = np.arange(parent_indices[index], ts.num_edges)
children = nd_children[children_index]
time = np.max(constrained_min_times[children])
# The constrained time of the node should be the age of the oldest child
if constrained_min_times[nd] <= time:
constrained_min_times[nd] = time
nearest_time = np.argmin(np.abs(timepoints - time))
lookup_index = priors.row_lookup[int(nd)]
grid_data[lookup_index][:nearest_time] = zero_value
assert np.all(constrained_min_times < constrained_max_times)
parent = tables.edges.parent[edges_start]
child_ids = tables.edges.child[edges_start:edges_end] # May contain dups
oldest_child_time = np.max(constrained_min_times[child_ids])
if oldest_child_time > constrained_min_times[parent]:
if priors.is_fixed(parent):
raise ValueError(
"Invalid fixed times: time for"
+ f"fixed node {parent} is younger than some of its descendants"
)
constrained_min_times[parent] = oldest_child_time
if constrained_min_times[parent] > 0 and not priors.is_fixed(parent):
nearest_time = np.argmin(np.abs(timepoints - constrained_min_times[parent]))
grid_data[priors.row_lookup[parent]][:nearest_time] = zero_value

rowmax = grid_data[:, 1:].max(axis=1)
if priors.probability_space == "linear":
Expand Down Expand Up @@ -1132,7 +1131,7 @@ def build_grid(
:param dict node_var_override: is a dict mapping node IDs to a variance value.
Any nodes listed here will be treated as non-fixed nodes whose prior is not
calculated from the conditional coalescent but instead are allocated a prior
whose mean is thenode time in the tree sequence and whose variance is the
whose mean is the node time in the tree sequence and whose variance is the
value in this dictionary. This allows sample nodes to be treated as nonfixed
nodes, and therefore dated. If ``None`` (default) then all sample nodes are
treated as occurring ata fixed time (as if this were an empty dict).
Expand Down

0 comments on commit de1c751

Please sign in to comment.