Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a sample_id column to ancestors #607

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@
intersphinx_mapping = {
"https://docs.python.org/3/": None,
"https://tskit.dev/tutorials": None,
"https://numpy.org/doc/stable/": None,
"https://tskit.dev/tskit/docs/stable": None,
"https://tskit.dev/msprime/docs/stable": None,
"https://numcodecs.readthedocs.io/en/stable/": None,
Expand Down
36 changes: 35 additions & 1 deletion tests/test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2193,13 +2193,42 @@ def test_insert_proxy_1_sample(self):
sample_data, sample_ids=[i]
)
assert ancestors.num_ancestors + 1 == ancestors_extra.num_ancestors
inserted = -1
inserted = -1 # inserted one should always be the last
self.assert_ancestor_full_span(ancestors_extra, [inserted])
assert np.array_equal(
ancestors_extra.ancestors_haplotype[inserted],
sample_data.sites_genotypes[:, i][used_sites],
)

def test_insert_proxy_sample_ids(self):
# Check that proxy samples inserted via ``sample_ids`` are tagged in
# ``ancestors_sample_id`` and mirror the sample they were built from.
ids = [1, 4]
sd, _ = self.get_example_data(10, 10, 40)
ancestors = tsinfer.generate_ancestors(sd).insert_proxy_samples(
sd, sample_ids=ids
)
# Boolean mask over the sample-data sites: True where the site position
# also occurs among the ancestors' site positions.
inference_sites = np.isin(sd.sites_position[:], ancestors.sites_position[:])
anc_sample_ids = ancestors.ancestors_sample_id[:]
# Only the requested proxies may carry a non-NULL sample id.
assert np.sum(anc_sample_ids != tskit.NULL) == len(ids)
for sample_id in ids:
# Each requested sample id must appear on exactly one ancestor.
assert np.sum(anc_sample_ids == sample_id) == 1
anc_sample = np.where(anc_sample_ids == sample_id)[0][0]
# The proxy ancestor spans every ancestor site and has no focal sites.
assert ancestors.ancestors_start[anc_sample] == 0
assert ancestors.ancestors_end[anc_sample] == ancestors.num_sites
assert len(ancestors.ancestors_focal_sites[anc_sample]) == 0

# Element [1] of the first item yielded by haplotypes() is taken as the
# sample's haplotype restricted to the masked sites.
haplotype = next(sd.haplotypes([sample_id], sites=inference_sites))[1]
assert np.all(ancestors.ancestors_haplotype[anc_sample] == haplotype)

def test_insert_proxy_different_sample_data(self):
# Check that no sample ids are recorded when the proxies come from a
# sample data object other than the one the ancestors were generated from.
ids = [1, 4]
sd, _ = self.get_example_data(10, 10, 40)
ancestors = tsinfer.generate_ancestors(sd)
# Second, separately built data set with the same parameters.
# NOTE(review): presumably this gets a different uuid from ``sd`` -- confirm.
sd_copy, _ = self.get_example_data(10, 10, num_ancestors=40)
ancestors_extra = ancestors.insert_proxy_samples(
sd_copy, sample_ids=ids, require_same_sample_data=False
)
# Every ancestor, including the inserted proxies, must have a NULL sample id.
assert np.all(ancestors_extra.ancestors_sample_id[:] == tskit.NULL)

def test_insert_proxy_sample_provenance(self):
sample_data, _ = self.get_example_data(10, 10, 40)
ancestors = tsinfer.generate_ancestors(sample_data)
Expand Down Expand Up @@ -2242,6 +2271,8 @@ def test_insert_proxy_time_historical_samples(self):
assert np.array_equal(
ancestors_extra.ancestors_time[-1], historical_sample_time + epsilon
)
assert np.sum(ancestors_extra.ancestors_sample_id[:] != tskit.NULL) == 1
assert ancestors_extra.ancestors_sample_id[-1] == 9

# Test 2 proxies, one historical, specified in different ways / orders
s_ids = np.array([9, 0])
Expand All @@ -2263,6 +2294,9 @@ def test_insert_proxy_time_historical_samples(self):
ancestors_extra.ancestors_haplotype[-1], G[:, 0][used_sites]
)
assert np.array_equal(ancestors_extra.ancestors_time[-1], epsilon)
assert np.sum(ancestors_extra.ancestors_sample_id[:] != tskit.NULL) == 2
assert ancestors_extra.ancestors_sample_id[-2] == 9
assert ancestors_extra.ancestors_sample_id[-1] == 0

def test_insert_proxy_sample_epsilon(self):
sample_data, _ = self.get_example_data(10, 10, 40)
Expand Down
69 changes: 54 additions & 15 deletions tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2031,6 +2031,10 @@ def haplotypes(self, samples=None, sites=None):
``None``, return haplotypes for all sample nodes, otherwise this may be a
numpy array (or array-like) object (converted to dtype=np.int32).
:param array sites: A numpy array of sites to use.


:return: An iterator returning successive instances of (sample_id, haplotype).
:rtype: iter(int, numpy.ndarray(dtype=int8))
"""
if samples is None:
samples = np.arange(self.num_samples)
Expand Down Expand Up @@ -2123,6 +2127,7 @@ class Ancestor:
time = attr.ib()
focal_sites = attr.ib()
haplotype = attr.ib()
sample_id = attr.ib()

def __eq__(self, other):
return (
Expand Down Expand Up @@ -2170,7 +2175,7 @@ class AncestorData(DataContainer):
"""

FORMAT_NAME = "tsinfer-ancestor-data"
FORMAT_VERSION = (3, 0)
FORMAT_VERSION = (3, 1)

def __init__(self, sample_data, **kwargs):
super().__init__(**kwargs)
Expand Down Expand Up @@ -2229,6 +2234,13 @@ def __init__(self, sample_data, **kwargs):
dtype="array:i1",
compressor=self._compressor,
)
self.data.create_dataset(
"ancestors/sample_id",
shape=(0,),
chunks=chunks,
compressor=self._compressor,
dtype=np.int32,
)

self._alloc_ancestor_writer()

Expand All @@ -2244,6 +2256,7 @@ def _alloc_ancestor_writer(self):
"time": self.ancestors_time,
"focal_sites": self.ancestors_focal_sites,
"haplotype": self.ancestors_haplotype,
"sample_id": self.ancestors_sample_id,
},
num_threads=self._num_flush_threads,
)
Expand All @@ -2265,6 +2278,7 @@ def __str__(self):
("ancestors/time", zarr_summary(self.ancestors_time)),
("ancestors/focal_sites", zarr_summary(self.ancestors_focal_sites)),
("ancestors/haplotype", zarr_summary(self.ancestors_haplotype)),
("ancestors/sample_id", zarr_summary(self.ancestors_sample_id)),
]
return super().__str__() + self._format_str(values)

Expand All @@ -2289,6 +2303,9 @@ def data_equal(self, other):
self.ancestors_focal_sites[:], other.ancestors_focal_sites[:]
)
and np_obj_equal(self.ancestors_haplotype[:], other.ancestors_haplotype[:])
and np.array_equal(
self.ancestors_sample_id[:], other.ancestors_sample_id[:]
)
)

@property
Expand Down Expand Up @@ -2340,6 +2357,10 @@ def ancestors_focal_sites(self):
def ancestors_haplotype(self):
return self.data["ancestors/haplotype"]

@property
def ancestors_sample_id(self):
return self.data["ancestors/sample_id"]

@property
def ancestors_length(self):
"""
Expand All @@ -2358,6 +2379,7 @@ def insert_proxy_samples(
*,
sample_ids=None,
epsilon=None,
map_ancestors=False,
allow_mutation=False,
require_same_sample_data=True,
**kwargs,
Expand All @@ -2370,7 +2392,8 @@ def insert_proxy_samples(

A *proxy sample ancestor* is an ancestor based upon a known sample. At
sites used in the full inference process, the haplotype of this ancestor
is identical to that of the sample on which it is based. The time of the
is identical to that of the sample on which it is based.
The time of the
ancestor is taken to be a fraction ``epsilon`` older than the sample on
which it is based.

Expand All @@ -2384,11 +2407,11 @@ def insert_proxy_samples(

.. note::

The proxy sample ancestors inserted here will correspond to extra nodes
in the inferred tree sequence. At sites which are not used in the full
The proxy sample ancestors inserted here will end up as extra nodes
in the inferred tree sequence, but at sites which are not used in the full
inference process (e.g. sites unique to a single historical sample),
these proxy sample ancestor nodes may have a different genotype from
their corresponding sample.
it is possible for these proxy sample ancestor nodes to have a different
genotype from their corresponding sample.

:param SampleData sample_data: The :class:`.SampleData` instance
from which to select the samples used to create extra ancestors.
Expand Down Expand Up @@ -2423,7 +2446,8 @@ def insert_proxy_samples(
to ensure that the encoding of alleles in ``sample_data`` matches the
encoding in the current :class:`AncestorData` instance (i.e. that in the
original :class:`.SampleData` instance on which the current ancestors
are based).
are based). Note that if ``sample_data`` is not the instance on which the
current ancestors are based, the sample_id is not recorded in the returned
object (it is set to ``tskit.NULL``).
:param \\**kwargs: Further arguments passed to the constructor when creating
the new :class:`AncestorData` instance which will be returned.

Expand Down Expand Up @@ -2521,7 +2545,11 @@ def insert_proxy_samples(
time=proxy_time,
focal_sites=[],
haplotype=haplotype,
sample_id=sample_id
if sample_data.uuid == self.sample_data_uuid
else tskit.NULL,
)

# Add any ancestors remaining in the current instance
while ancestor is not None:
other.add_ancestor(**attr.asdict(ancestor, filter=exclude_id))
Expand Down Expand Up @@ -2603,7 +2631,6 @@ def truncate_ancestors(
start = self.ancestors_start[:]
end = self.ancestors_end[:]
time = self.ancestors_time[:]
focal_sites = self.ancestors_focal_sites[:]
haplotypes = self.ancestors_haplotype[:]
if upper_time_bound > np.max(time) or lower_time_bound > np.max(time):
raise ValueError("Time bounds cannot be greater than older ancestor")
Expand Down Expand Up @@ -2641,16 +2668,12 @@ def truncate_ancestors(
)
start[anc.id] = insert_pos_start
end[anc.id] = insert_pos_end
time[anc.id] = anc.time
focal_sites[anc.id] = anc.focal_sites
haplotypes[anc.id] = anc.haplotype[
insert_pos_start - anc.start : insert_pos_end - anc.start
]
# TODO - record truncation in ancestors' metadata when supported
truncated.ancestors_start[:] = start
truncated.ancestors_end[:] = end
truncated.ancestors_time[:] = time
truncated.ancestors_focal_sites[:] = focal_sites
truncated.ancestors_haplotype[:] = haplotypes
truncated.record_provenance(command="truncate_ancestors")
truncated.finalise()
Expand All @@ -2671,6 +2694,12 @@ def set_inference_sites(self, site_ids):
sites in the sample data file, and the IDs must be in increasing order.

This must be called before the first call to :meth:`.add_ancestor`.

.. note::
To obtain a list of which sites in a sample data file or a tree sequence have
been placed into the ancestors file for use in inference, you can apply
:func:`numpy.isin` to the list of positions, e.g.
``np.isin(sample_data.sites_position[:], ancestors.sites_position[:])``
"""
self._check_build_mode()
position = self.sample_data.sites_position[:][site_ids]
Expand All @@ -2679,12 +2708,18 @@ def set_inference_sites(self, site_ids):
array[:] = position
self._num_alleles = self.sample_data.num_alleles(site_ids)

def add_ancestor(self, start, end, time, focal_sites, haplotype):
def add_ancestor(
self, start, end, time, focal_sites, haplotype, sample_id=tskit.NULL
):
"""
Adds an ancestor with the specified haplotype, with ancestral material over the
interval [start:end], that is associated with the specified timepoint and has new
mutations at the specified list of focal sites. Ancestors should be added in time
order, with the oldest first. The id of the added ancestor is returned.
mutations at the specified list of focal sites. If this ancestor is based on a
specific sample from the associated sample_data file (e.g. a historical sample)
then the ``sample_id`` in the sample data file can also be passed as a parameter.

Ancestors should be added in time order, with the oldest first. The id of
the added ancestor is returned.
"""
self._check_build_mode()
haplotype = tskit.util.safe_np_int_cast(haplotype, dtype=np.int8, copy=True)
Expand Down Expand Up @@ -2714,6 +2749,7 @@ def add_ancestor(self, start, end, time, focal_sites, haplotype):
time=time,
focal_sites=focal_sites,
haplotype=haplotype,
sample_id=sample_id,
)

def finalise(self):
Expand All @@ -2739,6 +2775,7 @@ def ancestor(self, id_):
time=self.ancestors_time[id_],
focal_sites=self.ancestors_focal_sites[id_],
haplotype=self.ancestors_haplotype[id_],
sample_id=self.ancestors_sample_id[id_],
)

def ancestors(self):
Expand All @@ -2750,6 +2787,7 @@ def ancestors(self):
end = self.ancestors_end[:]
time = self.ancestors_time[:]
focal_sites = self.ancestors_focal_sites[:]
sample_id = self.ancestors_sample_id[:]
for j, h in enumerate(chunk_iterator(self.ancestors_haplotype)):
yield Ancestor(
id=j,
Expand All @@ -2758,6 +2796,7 @@ def ancestors(self):
time=time[j],
focal_sites=focal_sites[j],
haplotype=h,
sample_id=sample_id[j],
)


Expand Down