Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rematch recombinants -> JSON #362

Merged
merged 2 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
316 changes: 170 additions & 146 deletions sc2ts/cli.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import json
import collections
import concurrent
import concurrent.futures as cf
import logging
import itertools
import platform
import pathlib
import sys
import contextlib
import dataclasses
import datetime
import time
import os
from typing import List

import numpy as np
import tqdm
Expand Down Expand Up @@ -572,144 +575,6 @@ def tally_lineages(ts, metadata, verbose):
df.to_csv(sys.stdout, sep="\t", index=False)


def examine_recombinant(work):
base_ts = tszip.load(work.ts_path)
# NOTE: this is needed because we have to have all the sites in the trees
# for tsinfer matching to work in the reverse direction. There is the
# possibility of subtle differences in the match path because of this.
# We probably won't offer this interface anyway for long, though, and
# the forward/backward in the inference
base_ts = sc2ts.pad_sites(base_ts)
with contextlib.ExitStack() as exit_stack:
alignment_store = exit_stack.enter_context(
sc2ts.AlignmentStore(work.alignments)
)
metadata_db = exit_stack.enter_context(sc2ts.MetadataDb(work.metadata))
metadata_matches = list(
metadata_db.query(f"SELECT * FROM samples WHERE strain=='{work.strain}'")
)
samples = sc2ts.preprocess(
metadata_matches,
base_ts,
metadata_matches[0]["date"],
alignment_store,
show_progress=False,
)
try:
sc2ts.match_recombinants(
samples,
base_ts,
num_mismatches=work.num_mismatches,
show_progress=False,
num_threads=0,
)
except Exception as e:
print("ERROR in matching", samples[0].strain)
raise e
return samples[0]


@dataclasses.dataclass(frozen=True)
class Work:
strain: str
ts_path: str
num_mismatches: int
alignments: str
metadata: str
sample: int
recombinant: int


@click.command()
@click.argument("alignments", type=click.Path(exists=True, dir_okay=False))
@click.argument("metadata", type=click.Path(exists=True, dir_okay=False))
@click.argument("tsz_prefix")
@click.argument("base_date")
@click.argument("out_tsz")
@click.option("--num-mismatches", default=3, type=float, help="num-mismatches")
@click.option("-v", "--verbose", count=True)
def annotate_recombinants(
alignments, metadata, tsz_prefix, base_date, out_tsz, num_mismatches, verbose
):
"""
Update recombinant nodes in the specified trees with additional
information about the matching process.
"""
setup_logging(verbose)
ts = tszip.load(tsz_prefix + base_date + ".ts.tsz")

recomb_samples = sc2ts.utils.get_recombinant_samples(ts)
mismatches = [num_mismatches]

work = []
for recombinant, sample in recomb_samples.items():
md = ts.node(sample).metadata
date = md["date"]
previous_date = datetime.date.fromisoformat(date)
previous_date -= datetime.timedelta(days=1)
tsz_path = f"{tsz_prefix}{previous_date}.ts"
for num_mismatches in mismatches:
work.append(
Work(
strain=md["strain"],
ts_path=tsz_path,
num_mismatches=num_mismatches,
alignments=alignments,
metadata=metadata,
sample=sample,
recombinant=recombinant,
)
)

results = {}
# for item in work:
# sample = examine_recombinant(item)
# results[item.recombinant] = sample
with concurrent.futures.ProcessPoolExecutor(max_workers=8) as executor:
future_to_work = {
executor.submit(examine_recombinant, item): item for item in work
}

bar = tqdm.tqdm(
concurrent.futures.as_completed(future_to_work), total=len(work)
)
for future in bar:
try:
data = future.result()
except Exception as exc:
print(f"Work item: {future_to_work[future]} raised exception!")
print(exc)
work = future_to_work[future]
results[work.recombinant] = data

tables = ts.dump_tables()
# This is probably very inefficient as we're writing back the metadata column
# many times
for recomb_node, sample in tqdm.tqdm(results.items(), desc="Updating metadata"):
row = tables.nodes[recomb_node]

hmm_md = [
{
"direction": "forward",
"path": [x.asdict() for x in sample.forward_path],
"mutations": [x.asdict() for x in sample.forward_mutations],
},
{
"direction": "reverse",
"path": [x.asdict() for x in sample.reverse_path],
"mutations": [x.asdict() for x in sample.reverse_mutations],
},
]
d = row.metadata
d["sc2ts"] = {"hmm": hmm_md}
# print(json.dumps(hmm_md, indent=2))
tables.nodes[recomb_node] = row.replace(metadata=d)

ts = tables.tree_sequence()
logging.info("Compressing output")
tszip.compress(ts, out_tsz)


@dataclasses.dataclass(frozen=True)
class HmmRun:
strain: str
Expand All @@ -726,9 +591,44 @@ def asjson(self):
return json.dumps(self.asdict())


@dataclasses.dataclass(frozen=True)
class MatchWork:
ts_path: str
samples: List
num_mismatches: int
direction: str


def _match_worker(work):
ts = tszip.load(work.ts_path)
mu, rho = sc2ts.solve_num_mismatches(work.num_mismatches)
matches = sc2ts.match_tsinfer(
samples=work.samples,
ts=ts,
mu=mu,
rho=rho,
num_threads=0,
show_progress=False,
# Maximum possible precision
likelihood_threshold=1e-200,
mirror_coordinates=work.direction == "reverse",
)
runs = []
for hmm_match, sample in zip(matches, work.samples):
runs.append(
HmmRun(
strain=sample.strain,
num_mismatches=work.num_mismatches,
direction=work.direction,
match=hmm_match,
)
)
return runs


@click.command()
@click.argument("alignments", type=click.Path(exists=True, dir_okay=False))
@click.argument("ts", type=click.Path(exists=True, dir_okay=False))
@click.argument("alignments_path", type=click.Path(exists=True, dir_okay=False))
@click.argument("ts_path", type=click.Path(exists=True, dir_okay=False))
@click.argument("strains", nargs=-1)
@click.option("--num-mismatches", default=3, type=int, help="num-mismatches")
@click.option(
Expand All @@ -747,8 +647,8 @@ def asjson(self):
@click.option("-v", "--verbose", count=True)
@click.option("-l", "--log-file", default=None, type=click.Path(dir_okay=False))
def run_match(
alignments,
ts,
alignments_path,
ts_path,
strains,
num_mismatches,
direction,
Expand All @@ -761,13 +661,13 @@ def run_match(
Run matches for a specified set of strains, outputting details to stdout as JSON.
"""
setup_logging(verbose, log_file)
ts = tszip.load(ts)
ts = tszip.load(ts_path)
if len(strains) == 0:
return
progress_title = "Match"
samples = sc2ts.preprocess(
list(strains),
alignments,
alignments_path,
show_progress=progress,
progress_title=progress_title,
keep_sites=ts.sites_position.astype(int),
Expand All @@ -776,6 +676,7 @@ def run_match(
for sample in samples:
if sample.haplotype is None:
raise ValueError(f"No alignment stored for {sample.strain}")

mu, rho = sc2ts.solve_num_mismatches(num_mismatches)
matches = sc2ts.match_tsinfer(
samples=samples,
Expand All @@ -800,6 +701,129 @@ def run_match(
print(run.asjson())


def find_previous_date_path(date, path_pattern):
"""
Find the path with the most-recent date to the specified one
matching the given pattern.
"""
date = datetime.date.fromisoformat(date)
for j in range(1, 30):
previous_date = date - datetime.timedelta(days=j)
path = pathlib.Path(path_pattern.format(previous_date))
logger.debug(f"Trying {path}")
if path.exists():
break
else:
raise ValueError(
f"No path exists for pattern {path_pattern} starting at {date}"
)
return path


@click.command()
@click.argument("alignments", type=click.Path(exists=True, dir_okay=False))
@click.argument("ts", type=click.Path(exists=True, dir_okay=False))
@click.argument("path_pattern")
@click.option(
"-k",
"--num-mismatches",
default=[3],
type=int,
multiple=True,
help="num-mismatches",
)
@click.option(
"--num-threads",
default=0,
type=int,
help="Number of match threads (default to one)",
)
@click.option("--progress/--no-progress", default=True)
@click.option("-v", "--verbose", count=True)
@click.option("-l", "--log-file", default=None, type=click.Path(dir_okay=False))
def run_rematch_recombinants(
alignments,
ts,
path_pattern,
num_mismatches,
num_threads,
progress,
verbose,
log_file,
):
setup_logging(verbose, log_file)
ts = tszip.load(ts)
# This is a map of recombinant node to the samples involved in
# the original causal sample group.
recombinant_strains = sc2ts.get_recombinant_strains(ts)
logger.info(
f"Got {len(recombinant_strains)} recombinants and "
f"{sum(len(v) for v in recombinant_strains.values())} strains"
)

# Map recombinants to originating date
recombinant_to_path = {}
strain_to_recombinant = {}
all_strains = []
for u, strains in recombinant_strains.items():
date_added = ts.node(u).metadata["sc2ts"]["date_added"]
base_ts_path = find_previous_date_path(date_added, path_pattern)
recombinant_to_path[u] = base_ts_path
for strain in strains:
strain_to_recombinant[strain] = u
all_strains.append(strain)

progress_title = "Recomb"
samples = sc2ts.preprocess(
all_strains,
alignments,
show_progress=progress,
progress_title=progress_title,
keep_sites=ts.sites_position.astype(int),
num_workers=num_threads,
)

recombinant_to_samples = collections.defaultdict(list)
for sample in samples:
if sample.haplotype is None:
raise ValueError(f"No alignment stored for {sample.strain}")
recombinant = strain_to_recombinant[sample.strain]
recombinant_to_samples[recombinant].append(sample)

work = []
for recombinant, samples in recombinant_to_samples.items():
for direction in ["forward", "reverse"]:
for k in num_mismatches:
work.append(
MatchWork(
recombinant_to_path[recombinant],
samples,
num_mismatches=k,
direction=direction,
)
)

bar = sc2ts.get_progress(None, progress_title, "HMM", progress, total=len(work))

def output(hmm_runs):
bar.update()
for run in hmm_runs:
print(run.asjson())

results = []
if num_threads == 0:
for w in work:
hmm_runs = _match_worker(w)
output(hmm_runs)
else:
with cf.ProcessPoolExecutor(num_threads) as executor:
futures = [executor.submit(_match_worker, w) for w in work]
for future in cf.as_completed(futures):
hmm_runs = future.result()
output(hmm_runs)
bar.close()


@click.version_option(core.__version__)
@click.group()
def cli():
Expand All @@ -819,6 +843,6 @@ def cli():
cli.add_command(list_dates)
cli.add_command(extend)
cli.add_command(validate)
cli.add_command(annotate_recombinants)
cli.add_command(run_match)
cli.add_command(run_rematch_recombinants)
cli.add_command(tally_lineages)
Loading