Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stacktrace reporting to AsyncVectorEnv #1119

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 1 addition & 27 deletions gymnasium/logger.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
"""Set of functions for logging messages."""

import sys
import warnings
from typing import Optional, Type

from gymnasium.utils import colorize


# Numeric severity levels. The logging helpers in this module compare
# `min_level` against these values: a message is emitted only when
# `min_level` is less than or equal to the level of that message.
DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40
DISABLED = 50

# Current logging threshold; the initial value 30 equals WARN.
min_level = 30

Expand All @@ -20,24 +16,6 @@
# Install a warnings filter with action "once" for DeprecationWarning raised from
# modules whose name matches ``^gymnasium\.`` (the empty message pattern matches any text).
warnings.filterwarnings("once", "", DeprecationWarning, module=r"^gymnasium\.")


def set_level(level: int):
    """Replace the module-wide logging threshold.

    Args:
        level: New value stored in the module global ``min_level``. The
            logging helpers emit a message only when ``min_level`` is less
            than or equal to that message's level.
    """
    global min_level
    min_level = level


def debug(msg: str, *args: object):
    """Log a debug message to ``sys.stderr`` when ``min_level <= DEBUG``.

    Args:
        msg: Message text, optionally containing ``%``-style placeholders.
        *args: Values interpolated into ``msg`` with the ``%`` operator.

    ``msg`` is %-formatted only when ``args`` is non-empty, so an already
    formatted message that contains a literal ``%`` (for example
    ``"50% done"``) is printed unchanged instead of raising a formatting
    error.
    """
    if min_level <= DEBUG:
        # Skip interpolation when there is nothing to interpolate.
        text = msg % args if args else msg
        print(f"DEBUG: {text}", file=sys.stderr)


def info(msg: str, *args: object):
    """Log an info message to ``sys.stderr`` when ``min_level <= INFO``.

    Args:
        msg: Message text, optionally containing ``%``-style placeholders.
        *args: Values interpolated into ``msg`` with the ``%`` operator.

    ``msg`` is %-formatted only when ``args`` is non-empty, so an already
    formatted message that contains a literal ``%`` (for example
    ``"50% done"``) is printed unchanged instead of raising a formatting
    error.
    """
    if min_level <= INFO:
        # Skip interpolation when there is nothing to interpolate.
        text = msg % args if args else msg
        print(f"INFO: {text}", file=sys.stderr)


def warn(
msg: str,
*args: object,
Expand Down Expand Up @@ -68,8 +46,4 @@ def deprecation(msg: str, *args: object):
def error(msg: str, *args: object):
    """Log an error message in red to ``sys.stderr`` when ``min_level <= ERROR``.

    Args:
        msg: Message text, optionally containing ``%``-style placeholders.
        *args: Values interpolated into ``msg`` with the ``%`` operator.

    ``msg`` is %-formatted only when ``args`` is non-empty, so an already
    formatted message that contains a literal ``%`` (common in f-string
    messages and tracebacks) is printed unchanged instead of raising a
    formatting error.
    """
    if min_level <= ERROR:
        # Skip interpolation when there is nothing to interpolate.
        text = msg % args if args else msg
        print(colorize(f"ERROR: {text}", "red"), file=sys.stderr)


# DEPRECATED:
setLevel = set_level
warnings.warn(colorize(f"ERROR: {msg % args}", "red"), stacklevel=3)
13 changes: 9 additions & 4 deletions gymnasium/vector/async_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import multiprocessing
import sys
import time
import traceback
from copy import deepcopy
from enum import Enum
from multiprocessing import Queue
Expand Down Expand Up @@ -623,18 +624,19 @@ def _raise_if_errors(self, successes: list[bool] | tuple[bool]):
num_errors = self.num_envs - sum(successes)
assert num_errors > 0
for i in range(num_errors):
index, exctype, value = self.error_queue.get()
index, exctype, value, trace = self.error_queue.get()

logger.error(
f"Received the following error from Worker-{index}: {exctype.__name__}: {value}"
f"Received the following error from Worker-{index} - Shutting it down"
)
logger.error(f"Shutting down Worker-{index}.")
logger.error(f"{trace}")

self.parent_pipes[index].close()
self.parent_pipes[index] = None

if i == num_errors - 1:
logger.error("Raising the last exception back to the main process.")
self._state = AsyncState.DEFAULT
raise exctype(value)

def __del__(self):
Expand Down Expand Up @@ -723,7 +725,10 @@ def _async_worker(
f"Received unknown command `{command}`. Must be one of [`reset`, `step`, `close`, `_call`, `_setattr`, `_check_spaces`]."
)
except (KeyboardInterrupt, Exception):
error_queue.put((index,) + sys.exc_info()[:2])
error_type, error_message, _ = sys.exc_info()
trace = traceback.format_exc()

error_queue.put((index, error_type, error_message, trace))
pipe.send((None, False))
finally:
env.close()
89 changes: 89 additions & 0 deletions tests/vector/test_async_vector_env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test the `AsyncVectorEnv` implementation."""

import re
import warnings
from multiprocessing import TimeoutError

import numpy as np
Expand All @@ -13,6 +14,7 @@
)
from gymnasium.spaces import Box, Discrete, MultiDiscrete, Tuple
from gymnasium.vector import AsyncVectorEnv
from tests.testing_env import GenericTestEnv
from tests.vector.testing_utils import (
CustomSpace,
make_custom_space_env,
Expand Down Expand Up @@ -345,3 +347,90 @@ def test_custom_space_async_vector_env_shared_memory():
with pytest.raises(ValueError):
env = AsyncVectorEnv(env_fns, shared_memory=True)
env.close(terminate=True)


def raise_error_reset(self, seed, options):
    """Reset hook for ``GenericTestEnv`` that fails when ``seed`` equals 1.

    Calls the parent ``reset`` with the given ``seed`` and ``options`` first,
    then either raises ``ValueError("Error in reset")`` (``seed == 1``) or
    returns a sampled observation together with an empty info dict.
    """
    super(GenericTestEnv, self).reset(seed=seed, options=options)
    if seed != 1:
        return self.observation_space.sample(), {}
    raise ValueError("Error in reset")


def raise_error_step(self, action):
    """Step hook that fails for any ``action`` greater than or equal to 1.

    Raises:
        ValueError: When ``action >= 1``; the message includes the action.

    Returns:
        A tuple of a sampled observation, reward ``0``, ``False`` for both the
        terminated and truncated flags, and an empty info dict.
    """
    should_fail = action >= 1
    if should_fail:
        raise ValueError(f"Error in step with {action}")

    observation = self.observation_space.sample()
    return observation, 0, False, False, {}


# Verify that exceptions raised inside sub-environments are re-raised by
# AsyncVectorEnv and that the accompanying error reports are emitted as warnings.
def test_async_vector_subenv_error():
# Two sub-environments whose reset/step hooks raise on specific inputs.
envs = AsyncVectorEnv(
[
lambda: GenericTestEnv(
reset_func=raise_error_reset, step_func=raise_error_step
)
]
* 2
)

# `raise_error_reset` only raises for seed == 1, so seeds [0, 0] must not
# produce any warning.
with warnings.catch_warnings(record=True) as caught_warnings:
envs.reset(seed=[0, 0])
assert len(caught_warnings) == 0

# A seed of 1 makes `raise_error_reset` raise; the ValueError is expected to
# surface in the calling process.
with warnings.catch_warnings(record=True) as caught_warnings:
with pytest.raises(ValueError, match="Error in reset"):
envs.reset(seed=[1, 0])

envs.close()

# Expected sequence: shutdown notice for Worker-0, the formatted traceback,
# then the "raising the last exception" message.
assert len(caught_warnings) == 3
assert (
"Received the following error from Worker-0 - Shutting it down"
in caught_warnings[0].message.args[0]
)
assert (
'in raise_error_reset\n raise ValueError("Error in reset")\nValueError: Error in reset'
in caught_warnings[1].message.args[0]
)
assert (
caught_warnings[2].message.args[0]
== "\x1b[31mERROR: Raising the last exception back to the main process.\x1b[0m"
)

# Fresh vector env with three sub-environments for the step-failure case.
envs = AsyncVectorEnv(
[
lambda: GenericTestEnv(
reset_func=raise_error_reset, step_func=raise_error_step
)
]
* 3
)

# `raise_error_step` raises for actions >= 1, so actions 1 and 2 are both
# expected to fail.
with warnings.catch_warnings(record=True) as caught_warnings:
with pytest.raises(ValueError, match="Error in step"):
envs.step([0, 1, 2])

envs.close()

# Two (shutdown notice, traceback) pairs plus the final message: 5 warnings.
assert len(caught_warnings) == 5
# due to variance in the step time, the order of warnings is random
assert re.match(
r"\x1b\[31mERROR: Received the following error from Worker-[12] - Shutting it down\x1b\[0m",
caught_warnings[0].message.args[0],
)
assert re.match(
r"\x1b\[31mERROR: Traceback \(most recent call last\):(?s:.)*in raise_error_step(?s:.)*ValueError: Error in step with [12]\n\x1b\[0m",
caught_warnings[1].message.args[0],
)
assert re.match(
r"\x1b\[31mERROR: Received the following error from Worker-[12] - Shutting it down\x1b\[0m",
caught_warnings[2].message.args[0],
)
assert re.match(
r"\x1b\[31mERROR: Traceback \(most recent call last\):(?s:.)*in raise_error_step(?s:.)*ValueError: Error in step with [12]\n\x1b\[0m",
caught_warnings[3].message.args[0],
)
assert (
caught_warnings[4].message.args[0]
== "\x1b[31mERROR: Raising the last exception back to the main process.\x1b[0m"
)
Loading