From 0bfee30718d1d3021cb090b793def4b34289fc7c Mon Sep 17 00:00:00 2001
From: Julie Ganeshan
Date: Wed, 11 Sep 2024 09:25:23 -0700
Subject: [PATCH] Move logic from TorchX CLI -> API, so MVAI can call it (#955)

Summary:
Pull Request resolved: https://github.com/pytorch/torchx/pull/955

MVAI's "light" is synchronous - you can immediately see the logs for jobs
you start. Only "fire" is asynchronous.

TorchX's API, since it's generic, *always* creates jobs that are
asynchronous. Therefore, there isn't a built-in interface for "tailing" the
stderr of every started process - just for tailing individual replicas of a
given role.

The TorchX CLI's `torchx run` command **has** implemented this, but its
implementation is coupled with the CLI implementations of `torchx run` and
`torchx log`.

This diff extracts the useful logic into a helper function of the TorchX API

Reviewed By: andywag

Differential Revision: D62463211
---
 torchx/cli/cmd_log.py          |  28 +----
 torchx/cli/cmd_run.py          |  20 ++--
 torchx/util/log_tee_helpers.py | 210 +++++++++++++++++++++++++++++++++
 3 files changed, 223 insertions(+), 35 deletions(-)
 create mode 100644 torchx/util/log_tee_helpers.py

diff --git a/torchx/cli/cmd_log.py b/torchx/cli/cmd_log.py
index 75bab07fc..0f31441d0 100644
--- a/torchx/cli/cmd_log.py
+++ b/torchx/cli/cmd_log.py
@@ -23,6 +23,10 @@
 from torchx.schedulers.api import Stream
 from torchx.specs.api import is_started
 from torchx.specs.builders import make_app_handle
+from torchx.util.log_tee_helpers import (
+    _find_role_replicas as find_role_replicas,
+    _prefix_line,
+)
 from torchx.util.types import none_throws
 
 
@@ -39,19 +43,6 @@ def validate(job_identifier: str) -> None:
         sys.exit(1)
 
 
-def _prefix_line(prefix: str, line: str) -> str:
-    """
-    _prefix_line ensure the prefix is still present even when dealing with return characters
-    """
-    if "\r" in line:
-        line = line.replace("\r", f"\r{prefix}")
-    if "\n" in line[:-1]:
-        line = line[:-1].replace("\n", f"\n{prefix}") + line[-1:]
-    if not line.startswith("\r"):
-        line = f"{prefix}{line}"
-    return line
-
-
 def print_log_lines(
     file: TextIO,
     runner: Runner,
@@ -167,17 +158,6 @@ def get_logs(
         raise threads_exceptions[0]
 
 
-def find_role_replicas(
-    app: specs.AppDef, role_name: Optional[str]
-) -> List[Tuple[str, int]]:
-    role_replicas = []
-    for role in app.roles:
-        if role_name is None or role_name == role.name:
-            for i in range(role.num_replicas):
-                role_replicas.append((role.name, i))
-    return role_replicas
-
-
 class CmdLog(SubCommand):
     def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
         subparser.add_argument(
diff --git a/torchx/cli/cmd_run.py b/torchx/cli/cmd_run.py
index f2f77c7bb..3246652d3 100644
--- a/torchx/cli/cmd_run.py
+++ b/torchx/cli/cmd_run.py
@@ -21,7 +21,6 @@
 import torchx.specs as specs
 from torchx.cli.argparse_util import ArgOnceAction, torchxconfig_run
 from torchx.cli.cmd_base import SubCommand
-from torchx.cli.cmd_log import get_logs
 from torchx.runner import config, get_runner, Runner
 from torchx.runner.config import load_sections
 from torchx.schedulers import get_default_scheduler_name, get_scheduler_factories
@@ -32,6 +31,7 @@
     get_builtin_source,
     get_components,
 )
+from torchx.util.log_tee_helpers import tee_logs
 from torchx.util.types import none_throws
 
 
@@ -288,16 +288,14 @@ def _wait_and_exit(self, runner: Runner, app_handle: str, log: bool) -> None:
             logger.debug(status)
 
     def _start_log_thread(self, runner: Runner, app_handle: str) -> threading.Thread:
-        thread = threading.Thread(
-            target=get_logs,
-            kwargs={
-                "file": sys.stderr,
-                "runner": runner,
-                "identifier": app_handle,
-                "regex": None,
-                "should_tail": True,
-            },
+        thread = tee_logs(
+            dst=sys.stderr,
+            app_handle=app_handle,
+            regex=None,
+            runner=runner,
+            should_tail=True,
+            streams=None,
+            colorize=not sys.stderr.closed and sys.stderr.isatty(),
         )
-        thread.daemon = True
         thread.start()
         return thread
diff --git a/torchx/util/log_tee_helpers.py b/torchx/util/log_tee_helpers.py
new file mode 100644
index 000000000..0bad6f8f2
--- /dev/null
+++ b/torchx/util/log_tee_helpers.py
@@ -0,0 +1,210 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+"""
+If you're wrapping the TorchX API with your own CLI, these functions can
+help show the logs of the job within your CLI, just like
+`torchx log`
+"""
+
+import logging
+import threading
+from queue import Queue
+from typing import List, Optional, TextIO, Tuple, TYPE_CHECKING
+
+from torchx.util.types import none_throws
+
+if TYPE_CHECKING:
+    from torchx.runner.api import Runner
+    from torchx.schedulers.api import Stream
+    from torchx.specs.api import AppDef
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+# A torchX job can have stderr/stdout for many replicas, of many roles
+# The scheduler API has functions that allow us to get,
+# with unspecified detail, the log lines of a given replica of
+# a given role.
+#
+# So, to neatly tee the results, we:
+# 1) Determine every role ID / replica ID pair we want to monitor
+# 2) Request the given stderr / stdout / combined streams from them (1 thread each)
+# 3) Concatenate each of those streams to a given destination file
+
+
+def _find_role_replicas(
+    app: "AppDef",
+    role_name: Optional[str],
+) -> List[Tuple[str, int]]:
+    """
+    Enumerate all (role, replica id) pairs in the given AppDef.
+    Replica IDs are 0-indexed, and range up to num_replicas,
+    for each role.
+    If role_name is provided, filters to only that name.
+    """
+    role_replicas = []
+    for role in app.roles:
+        if role_name is None or role_name == role.name:
+            for i in range(role.num_replicas):
+                role_replicas.append((role.name, i))
+    return role_replicas
+
+
+def _prefix_line(prefix: str, line: str) -> str:
+    """
+    _prefix_line ensure the prefix is still present even when dealing with return characters
+    """
+    if "\r" in line:
+        line = line.replace("\r", f"\r{prefix}")
+    if "\n" in line[:-1]:
+        line = line[:-1].replace("\n", f"\n{prefix}") + line[-1:]
+    if not line.startswith("\r"):
+        line = f"{prefix}{line}"
+    return line
+
+
+def _print_log_lines_for_role_replica(
+    dst: TextIO,
+    app_handle: str,
+    regex: Optional[str],
+    runner: "Runner",
+    which_role: str,
+    which_replica: int,
+    exceptions: "Queue[Exception]",
+    should_tail: bool,
+    streams: Optional["Stream"],
+    colorize: bool = False,
+) -> None:
+    """
+    Helper function that'll run in parallel - one
+    per monitored replica of a given role.
+
+    Based on print_log_lines .. but not designed for TTY
+    """
+    try:
+        for line in runner.log_lines(
+            app_handle,
+            which_role,
+            which_replica,
+            regex,
+            should_tail=should_tail,
+            streams=streams,
+        ):
+            if colorize:
+                color_begin = "\033[32m"
+                color_end = "\033[0m"
+            else:
+                color_begin = ""
+                color_end = ""
+            prefix = f"{color_begin}{which_role}/{which_replica}{color_end} "
+            print(_prefix_line(prefix, line), file=dst, end="", flush=True)
+    except Exception as e:
+        exceptions.put(e)
+        raise
+
+
+def _start_threads_to_monitor_role_replicas(
+    dst: TextIO,
+    app_handle: str,
+    regex: Optional[str],
+    runner: "Runner",
+    which_role: Optional[str] = None,
+    should_tail: bool = False,
+    streams: Optional["Stream"] = None,
+    colorize: bool = False,
+) -> None:
+    threads = []
+
+    app = none_throws(runner.describe(app_handle))
+    replica_ids = _find_role_replicas(app, role_name=which_role)
+
+    # Holds exceptions raised by all threads, in a thread-safe
+    # object
+    exceptions = Queue()
+
+    if not replica_ids:
+        valid_roles = [role.name for role in app.roles]
+        raise ValueError(
+            f"{which_role} is not a valid role name. Available: {valid_roles}"
+        )
+
+    for role_name, replica_id in replica_ids:
+        threads.append(
+            threading.Thread(
+                target=_print_log_lines_for_role_replica,
+                kwargs={
+                    "dst": dst,
+                    "runner": runner,
+                    "app_handle": app_handle,
+                    "which_role": role_name,
+                    "which_replica": replica_id,
+                    "regex": regex,
+                    "should_tail": should_tail,
+                    "exceptions": exceptions,
+                    "streams": streams,
+                    "colorize": colorize,
+                },
+                daemon=True,
+            )
+        )
+
+    for t in threads:
+        t.start()
+
+    for t in threads:
+        t.join()
+
+    # Retrieve all exceptions, print all except one and raise the first recorded exception
+    threads_exceptions = []
+    while not exceptions.empty():
+        threads_exceptions.append(exceptions.get())
+
+    if len(threads_exceptions) > 0:
+        for i in range(1, len(threads_exceptions)):
+            logger.error(threads_exceptions[i])
+
+        raise threads_exceptions[0]
+
+
+def tee_logs(
+    dst: TextIO,
+    app_handle: str,
+    regex: Optional[str],
+    runner: "Runner",
+    should_tail: bool = False,
+    streams: Optional["Stream"] = None,
+    colorize: bool = False,
+) -> threading.Thread:
+    """
+    Makes a thread, which in turn will start 1 thread per replica
+    per role, that tees that role-replica's logs to the given
+    destination buffer.
+
+    You'll need to start and join with this parent thread.
+
+    dst: TextIO to tee the logs into
+    app_handle: The return value of runner.run() or runner.schedule()
+    regex: Regex to filter the logs that are tee-d
+    runner: The Runner you used to schedule the job
+    should_tail: If true, continue until we run out of logs. Otherwise, just fetch
+        what's available
+    streams: Whether to fetch STDERR, STDOUT, or the temporally COMBINED (default) logs
+    """
+    thread = threading.Thread(
+        target=_start_threads_to_monitor_role_replicas,
+        kwargs={
+            "dst": dst,
+            "runner": runner,
+            "app_handle": app_handle,
+            "regex": regex,  # BUGFIX: was hardcoded to None, ignoring the caller's filter
+            "should_tail": should_tail, "streams": streams,  # BUGFIX: should_tail was hardcoded True; streams was silently dropped
+            "colorize": colorize,
+        },
+        daemon=True,
+    )
+    return thread