Skip to content

Commit

Permalink
Merge pull request #265 from jeromekelleher/rejig-progress-ui
Browse files Browse the repository at this point in the history
Rejig progress UI for inference
  • Loading branch information
jeromekelleher authored Sep 13, 2024
2 parents 060ae61 + 5a3250e commit 14accb2
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 88 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ dependencies = [
"scipy",
"click",
"zarr<2.18",
"humanize",
"resource",
]
dynamic = ["version"]

Expand Down
1 change: 0 additions & 1 deletion run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ metadata=$datadir/metadata.db
matches=$resultsdir/matches.db

dates=`python3 -m sc2ts list-dates $metadata | grep -v 2021-12-31 | head -n 14`
echo $dates

options="--num-threads $num_threads -vv -l $logfile "
# options+="--max-submission-delay $max_submission_delay "
Expand Down
12 changes: 6 additions & 6 deletions sc2ts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@

from . core import __version__
from .core import __version__

# FIXME
from . core import *
from . alignments import *
from . metadata import *
from . inference import *
from .core import *
from .alignments import *
from .metadata import *
from .inference import *
from .validation import *
51 changes: 42 additions & 9 deletions sc2ts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
import concurrent
import logging
import platform
import random
import pathlib
import sys
import contextlib
import dataclasses
import datetime
import pickle
import time
import os

import numpy as np
import tqdm
Expand All @@ -18,14 +17,42 @@
import tsinfer
import click
import daiquiri
import humanize
import pandas as pd

try:
import resource
except ImportError:
resource = None # resource.getrusage absent on windows, so skip outputting max mem

import sc2ts
from . import core
from . import inference
from . import utils

logger = logging.getLogger(__name__)

__before = time.time()


def summarise_usage():
    """
    Return a one-line, human-readable summary of the resources used so far.

    The summary contains the wall-clock time elapsed since this module was
    imported (humanized and in seconds), the user and system CPU times
    reported by ``os.times()``, and the peak resident set size. When the
    ``resource`` module is unavailable the ``max_memory`` field is still
    present, with the value ``?``, so the line has the same shape everywhere.
    """
    wall_time = time.time() - __before
    # Take a single snapshot so the user and system figures are consistent.
    cpu_times = os.times()
    if resource is None:
        # Don't report max memory on Windows. We could do this using the psutil lib, via
        # psutil.Process(os.getpid()).get_ext_memory_info().peak_wset if demand exists
        max_mem_text = "?"
    else:
        max_mem = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        if sys.platform != "darwin":
            max_mem *= 1024  # Linux and other OSs (e.g. freeBSD) report maxrss in kb
        max_mem_text = humanize.naturalsize(max_mem, binary=True)
    return (
        f"Done in {humanize.naturaldelta(wall_time)}; "
        f"elapsed={wall_time:.2f}; user={cpu_times.user:.2f}; "
        f"sys={cpu_times.system:.2f}; max_memory={max_mem_text}"
    )


def get_environment():
"""
Expand Down Expand Up @@ -207,16 +234,18 @@ def initialise(ts, match_db, additional_problematic_sites, verbose, log_file):
f"Excluding additional {len(additional_problematic)} problematic sites"
)

base_ts = inference.initial_ts(additional_problematic)
base_ts = sc2ts.initial_ts(additional_problematic)
base_ts.dump(ts)
logger.info(f"New base ts at {ts}")
inference.MatchDb.initialise(match_db)
sc2ts.MatchDb.initialise(match_db)


@click.command()
@click.argument("metadata", type=click.Path(exists=True, dir_okay=False))
@click.option("--counts/--no-counts", default=False)
@click.option("--after", default="1900-01-01", help="show dates after the specified value")
@click.option(
"--after", default="1900-01-01", help="show dates after the specified value"
)
@click.option("-v", "--verbose", count=True)
@click.option("-l", "--log-file", default=None, type=click.Path(dir_okay=False))
def list_dates(metadata, counts, after, verbose, log_file):
Expand Down Expand Up @@ -342,7 +371,7 @@ def extend(
abort=True,
)
match_db.delete_newer(date)
ts_out = inference.extend(
ts_out = sc2ts.extend(
alignment_store=alignment_store,
metadata_db=metadata_db,
base_ts=base,
Expand All @@ -358,6 +387,10 @@ def extend(
show_progress=progress,
)
add_provenance(ts_out, output_ts)
resource_usage = f"{date}:{summarise_usage()}"
logger.info(resource_usage)
if progress:
print(resource_usage, file=sys.stderr)


@click.command()
Expand All @@ -372,7 +405,7 @@ def validate(alignment_db, ts_file, verbose):

ts = tszip.load(ts_file)
with sc2ts.AlignmentStore(alignment_db) as alignment_store:
inference.validate(ts, alignment_store, show_progress=True)
sc2ts.validate(ts, alignment_store, show_progress=True)


@click.command()
Expand Down
115 changes: 43 additions & 72 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,36 @@
logger = logging.getLogger(__name__)


def get_progress(iterable, date, phase, show_progress, total=None):
    """
    Wrap ``iterable`` in a tqdm progress bar labelled ``"{date}:{phase}"``.

    The bar is disabled when ``show_progress`` is false, and ``total`` is
    forwarded to tqdm unchanged. The tqdm instance is returned.
    """
    # Left-align the description in a fixed-width column ahead of the bar.
    layout = "".join(
        [
            "{desc:<22}{percentage:3.0f}%|{bar}",
            "| {n_fmt}/{total_fmt} [{elapsed}, {rate_fmt}{postfix}]",
        ]
    )
    options = dict(
        total=total,
        desc=f"{date}:{phase}",
        disable=not show_progress,
        bar_format=layout,
        dynamic_ncols=True,
        smoothing=0.01,
        unit_scale=True,
    )
    return tqdm.tqdm(iterable, **options)


class TsinferProgressMonitor(tsinfer.progress.ProgressMonitor):
    """
    A tsinfer progress monitor whose bars are built by ``get_progress`` and
    labelled with a date and phase.
    """

    def __init__(self, date, phase, *args, **kwargs):
        # Record the labels, then hand the remaining arguments to the base class.
        self.date, self.phase = date, phase
        super().__init__(*args, **kwargs)

    def get(self, key, total):
        """
        Create, remember and return a progress bar for ``total`` items.

        ``key`` is not used here; the bar is labelled from the date and phase
        given at construction and is shown only if ``self.enabled`` is true.
        """
        progress = get_progress(
            None,
            self.date,
            phase=self.phase,
            show_progress=self.enabled,
            total=total,
        )
        self.current_instance = progress
        return progress


class MatchDb:
def __init__(self, path):
uri = f"file:{path}"
Expand Down Expand Up @@ -287,60 +317,6 @@ def increment_time(date, ts):
return tables.tree_sequence()


def _validate_samples(ts, samples, alignment_store, show_progress):
    """
    Compare the genotypes of ``samples`` in ``ts`` with their stored alignments.

    Each sample's alignment is looked up in ``alignment_store`` by the
    ``strain`` recorded in its node metadata, encoded and masked, and
    restricted to the site positions of ``ts``. Raises ``ValueError`` if, at
    any site, a stored value other than -1 differs from the corresponding
    genotype in the tree sequence.
    """
    strain_names = [ts.node(node_id).metadata["strain"] for node_id in samples]
    # Column j holds the encoded alignment of the j-th sample at every site.
    stored = np.zeros((ts.num_sites, len(samples)), dtype=np.int8)
    site_positions = ts.sites_position.astype(int)
    bar_options = dict(position=1, leave=False, disable=not show_progress)

    with tqdm.tqdm(
        enumerate(strain_names),
        desc="Read",
        total=len(strain_names),
        **bar_options,
    ) as bar:
        for column, strain in bar:
            encoded = alignments.encode_and_mask(alignment_store[strain])
            stored[:, column] = encoded.alignment[site_positions]

    variants = ts.variants(samples=samples, alleles=tuple(core.ALLELES))
    with tqdm.tqdm(
        variants,
        desc="Check",
        total=ts.num_sites,
        **bar_options,
    ) as bar:
        for variant in bar:
            expected = stored[variant.site.id]
            # Entries equal to -1 are excluded from the comparison.
            known = expected != -1
            if np.any(variant.genotypes[known] != expected[known]):
                raise ValueError("Data mismatch")


def validate(ts, alignment_store, show_progress=False):
    """
    Check that all the samples in the specified tree sequence are correctly
    representing the original alignments.

    The first sample is skipped (presumably the reference sequence -- confirm
    against the caller) and the remainder are checked in chunks of 1000 by
    ``_validate_samples``, which raises ``ValueError`` on a mismatch. If
    ``show_progress`` is true, a progress bar over the chunks is displayed.
    """
    samples = ts.samples()[1:]
    chunk_size = 10**3
    # Derive the chunk boundaries from the samples actually being checked
    # rather than from ts.num_samples, which also counts the skipped first
    # sample. This way an empty chunk is never validated, and the final
    # partial chunk is included in the progress bar.
    chunk_starts = range(0, len(samples), chunk_size)
    for start in tqdm.tqdm(chunk_starts, position=0, disable=not show_progress):
        chunk = samples[start : start + chunk_size]
        _validate_samples(ts, chunk, alignment_store, show_progress)


@dataclasses.dataclass
class Sample:
strain: str
Expand Down Expand Up @@ -437,6 +413,8 @@ def match_samples(
likelihood_threshold=likelihood_threshold,
num_threads=num_threads,
show_progress=show_progress,
date=date,
phase=f"match({k})",
)

exceeding_threshold = []
Expand All @@ -461,6 +439,8 @@ def match_samples(
rho=rho,
num_threads=num_threads,
show_progress=show_progress,
date=date,
phase=f"match(F)",
)
recombinants = []
for sample in run_batch:
Expand Down Expand Up @@ -509,11 +489,7 @@ def preprocess(samples_md, base_ts, date, alignment_store, show_progress=False):
problematic_sites = core.get_problematic_sites()

samples = []
with tqdm.tqdm(
samples_md,
desc=f"Preprocess",
disable=not show_progress,
) as bar:
with get_progress(samples_md, date, "preprocess", show_progress) as bar:
for md in bar:
strain = md["strain"]
try:
Expand Down Expand Up @@ -639,6 +615,7 @@ def extend(
date=date,
min_group_size=1,
show_progress=show_progress,
phase="add(close)",
)

logger.info("Looking for retrospective matches")
Expand All @@ -652,6 +629,7 @@ def extend(
min_group_size=min_group_size,
min_different_dates=3, # TODO parametrize
show_progress=show_progress,
phase="add(retro)",
)
return update_top_level_metadata(ts, date)

Expand Down Expand Up @@ -785,6 +763,7 @@ def add_matching_results(
min_group_size=1,
min_different_dates=1,
show_progress=False,
phase=None,
):
logger.info(f"Querying match DB WHERE: {where_clause}")
samples = match_db.get(where_clause)
Expand Down Expand Up @@ -814,13 +793,7 @@ def add_matching_results(

attach_nodes = []
added_samples = []

with tqdm.tqdm(
grouped_matches.items(),
desc=f"Build:{date}",
total=len(grouped_matches),
disable=not show_progress,
) as bar:
with get_progress(list(grouped_matches.items()), date, phase, show_progress) as bar:
for (path, reversions), match_samples in bar:
different_dates = set(sample.date for sample in match_samples)
# TODO (1) add group ID from hash of samples (2) better logging of path
Expand Down Expand Up @@ -1283,6 +1256,8 @@ def match_tsinfer(
likelihood_threshold=None,
num_threads=0,
show_progress=False,
date=None,
phase=None,
mirror_coordinates=False,
):
if len(samples) == 0:
Expand All @@ -1302,12 +1277,8 @@ def match_tsinfer(
# Let's say a double break with 5 mutations is the most unlikely thing
# we're interested in solving for exactly.
likelihood_threshold = rho**2 * mu**5
pm = tsinfer.inference._get_progress_monitor(
show_progress,
generate_ancestors=False,
match_ancestors=False,
match_samples=False,
)

pm = TsinferProgressMonitor(date, phase, enabled=show_progress)

# This is just working around tsinfer's input checking logic. The actual value
# we're incrementing by has no effect.
Expand Down
60 changes: 60 additions & 0 deletions sc2ts/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import numpy as np

import tqdm

from . import alignments
from . import core


def _validate_samples(ts, samples, alignment_store, show_progress):
    """
    Compare the genotypes of ``samples`` in ``ts`` with their stored alignments.

    Each sample's alignment is looked up in ``alignment_store`` by the
    ``strain`` recorded in its node metadata, encoded and masked, and
    restricted to the site positions of ``ts``. Raises ``ValueError`` if, at
    any site, a stored value other than -1 differs from the corresponding
    genotype in the tree sequence.
    """
    strain_names = [ts.node(node_id).metadata["strain"] for node_id in samples]
    # Column j holds the encoded alignment of the j-th sample at every site.
    stored = np.zeros((ts.num_sites, len(samples)), dtype=np.int8)
    site_positions = ts.sites_position.astype(int)
    bar_options = dict(position=1, leave=False, disable=not show_progress)

    with tqdm.tqdm(
        enumerate(strain_names),
        desc="Read",
        total=len(strain_names),
        **bar_options,
    ) as bar:
        for column, strain in bar:
            encoded = alignments.encode_and_mask(alignment_store[strain])
            stored[:, column] = encoded.alignment[site_positions]

    variants = ts.variants(samples=samples, alleles=tuple(core.ALLELES))
    with tqdm.tqdm(
        variants,
        desc="Check",
        total=ts.num_sites,
        **bar_options,
    ) as bar:
        for variant in bar:
            expected = stored[variant.site.id]
            # Entries equal to -1 are excluded from the comparison.
            known = expected != -1
            if np.any(variant.genotypes[known] != expected[known]):
                raise ValueError("Data mismatch")


def validate(ts, alignment_store, show_progress=False):
    """
    Check that all the samples in the specified tree sequence are correctly
    representing the original alignments.

    The first sample is skipped (presumably the reference sequence -- confirm
    against the caller) and the remainder are checked in chunks of 1000 by
    ``_validate_samples``, which raises ``ValueError`` on a mismatch. If
    ``show_progress`` is true, a progress bar over the chunks is displayed.
    """
    samples = ts.samples()[1:]
    chunk_size = 10**3
    # Derive the chunk boundaries from the samples actually being checked
    # rather than from ts.num_samples, which also counts the skipped first
    # sample. This way an empty chunk is never validated, and the final
    # partial chunk is included in the progress bar.
    chunk_starts = range(0, len(samples), chunk_size)
    for start in tqdm.tqdm(chunk_starts, position=0, disable=not show_progress):
        chunk = samples[start : start + chunk_size]
        _validate_samples(ts, chunk, alignment_store, show_progress)

0 comments on commit 14accb2

Please sign in to comment.