Time Travel Debugging (#323)
Browse files Browse the repository at this point in the history
* First attempt at merge

* Small fixes to time_travel
HowieG authored Aug 1, 2024
1 parent 5bf077f commit de23413
Showing 6 changed files with 474 additions and 9 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -161,4 +161,7 @@ cython_debug/

.vscode/
.benchmarks/
.DS_Store
.DS_Store

agentops_time_travel.json
.agentops_time_travel.yaml
41 changes: 41 additions & 0 deletions agentops/cli.py
@@ -0,0 +1,41 @@
import argparse
from .time_travel import fetch_time_travel_id, set_time_travel_active_state


def main():
parser = argparse.ArgumentParser(description="AgentOps CLI")
subparsers = parser.add_subparsers(dest="command")

timetravel_parser = subparsers.add_parser(
"timetravel", help="Time Travel Debugging commands", aliases=["tt"]
)
timetravel_parser.add_argument(
"branch_name",
type=str,
nargs="?",
help="Given a branch name, fetches the cache file for Time Travel Debugging. Turns on feature by default",
)
timetravel_parser.add_argument(
"--on",
action="store_true",
help="Turns on Time Travel Debugging",
)
timetravel_parser.add_argument(
"--off",
action="store_true",
help="Turns off Time Travel Debugging",
)

args = parser.parse_args()

if args.command in ["timetravel", "tt"]:
if args.branch_name:
fetch_time_travel_id(args.branch_name)
if args.on:
set_time_travel_active_state(True)
if args.off:
set_time_travel_active_state(False)


if __name__ == "__main__":
main()
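For orientation, a minimal usage sketch of the new subcommand. The `python -m agentops.cli` invocation and the branch name are assumptions for illustration; this commit does not wire up a console-script entry point.

# Usage sketch (hypothetical branch name; equivalent to
# `python -m agentops.cli timetravel my-branch --on`):
from agentops.time_travel import fetch_time_travel_id, set_time_travel_active_state

fetch_time_travel_id("my-branch")    # fetch the cached completions for this branch
set_time_travel_active_state(True)   # flip Time_Travel_Debugging_Active on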
126 changes: 118 additions & 8 deletions agentops/llm_tracker.py
@@ -12,6 +12,15 @@
from .event import ActionEvent, ErrorEvent, LLMEvent, ToolEvent
from .helpers import check_call_stack_for_agent_id, get_ISO_time
from .log_config import logger
import inspect
from typing import Optional
import pprint
from .time_travel import (
fetch_completion_override_from_time_travel_cache,
# fetch_prompt_override_from_time_travel_cache,
)

original_func = {}
original_create = None
@@ -246,7 +255,7 @@ async def async_generator():

# v1.0.0+ responses are objects
try:
self.llm_event.returns = response.model_dump()
self.llm_event.returns = response
self.llm_event.agent_id = check_call_stack_for_agent_id()
self.llm_event.prompt = kwargs["messages"]
self.llm_event.prompt_tokens = response.usage.prompt_tokens
@@ -408,7 +417,7 @@ def generator():
# Not enough to record StreamedChatResponse_ToolCallsGeneration because the tool may not have been called

try:
self.llm_event.returns = response.dict()
self.llm_event.returns = response
self.llm_event.agent_id = check_call_stack_for_agent_id()
self.llm_event.prompt = []
if response.chat_history:
@@ -614,6 +623,7 @@ async def async_generator():

def override_openai_v1_completion(self):
from openai.resources.chat import completions
from openai.types.chat import ChatCompletion, ChatCompletionChunk

# Store the original method
global original_create
@@ -624,6 +634,37 @@ def patched_function(*args, **kwargs):
session = kwargs.get("session", None)
if "session" in kwargs.keys():
del kwargs["session"]

completion_override = fetch_completion_override_from_time_travel_cache(
kwargs
)
if completion_override:
result_model = None
pydantic_models = (ChatCompletion, ChatCompletionChunk)
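# The cached override may be a full completion or a streamed chunk, so try both shapes.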
for pydantic_model in pydantic_models:
try:
result_model = pydantic_model.model_validate_json(
completion_override
)
break
except Exception:
pass  # not this shape; try the next model

if result_model is None:
logger.error(
f"Time Travel: Pydantic validation failed for {pydantic_models} \n"
f"Time Travel: Completion override was:\n"
f"{pprint.pformat(completion_override)}"
)
return None
return self._handle_response_v1_openai(
result_model, kwargs, init_timestamp, session=session
)

# prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs)
# if prompt_override:
# kwargs["messages"] = prompt_override["messages"]

# Call the original function with its original arguments
result = original_create(*args, **kwargs)
return self._handle_response_v1_openai(
@@ -635,17 +676,51 @@ def patched_function(*args, **kwargs):

def override_openai_v1_async_completion(self):
from openai.resources.chat import completions
from openai.types.chat import ChatCompletion, ChatCompletionChunk

# Store the original method
global original_create_async
original_create_async = completions.AsyncCompletions.create

async def patched_function(*args, **kwargs):
# Call the original function with its original arguments

init_timestamp = get_ISO_time()

session = kwargs.get("session", None)
if "session" in kwargs.keys():
del kwargs["session"]

completion_override = fetch_completion_override_from_time_travel_cache(
kwargs
)
if completion_override:
result_model = None
pydantic_models = (ChatCompletion, ChatCompletionChunk)
for pydantic_model in pydantic_models:
try:
result_model = pydantic_model.model_validate_json(
completion_override
)
break
except Exception:
pass  # not this shape; try the next model

if result_model is None:
logger.error(
f"Time Travel: Pydantic validation failed for {pydantic_models} \n"
f"Time Travel: Completion override was:\n"
f"{pprint.pformat(completion_override)}"
)
return None
return self._handle_response_v1_openai(
result_model, kwargs, init_timestamp, session=session
)

# prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs)
# if prompt_override:
# kwargs["messages"] = prompt_override["messages"]

# Call the original function with its original arguments
result = await original_create_async(*args, **kwargs)
return self._handle_response_v1_openai(
result, kwargs, init_timestamp, session=session
@@ -656,16 +731,34 @@ async def patched_function(*args, **kwargs):

def override_litellm_completion(self):
import litellm
from openai.types.chat import (
ChatCompletion,
) # Note: litellm calls all LLM APIs using the OpenAI format

original_create = litellm.completion

def patched_function(*args, **kwargs):
init_timestamp = get_ISO_time()

session = kwargs.get("session", None)
if "session" in kwargs.keys():
del kwargs["session"]

completion_override = fetch_completion_override_from_time_travel_cache(
kwargs
)
if completion_override:
result_model = ChatCompletion.model_validate_json(completion_override)
return self._handle_response_v1_openai(
result_model, kwargs, init_timestamp, session=session
)

# prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs)
# if prompt_override:
# kwargs["messages"] = prompt_override["messages"]

# Call the original function with its original arguments
result = original_create(*args, **kwargs)
# Note: litellm calls all LLM APIs using the OpenAI format
return self._handle_response_v1_openai(
result, kwargs, init_timestamp, session=session
)
@@ -674,17 +767,34 @@ def patched_function(*args, **kwargs):

def override_litellm_async_completion(self):
import litellm
from openai.types.chat import (
ChatCompletion,
) # Note: litellm calls all LLM APIs using the OpenAI format

original_create = litellm.acompletion
original_create_async = litellm.acompletion

async def patched_function(*args, **kwargs):
# Call the original function with its original arguments
init_timestamp = get_ISO_time()

session = kwargs.get("session", None)
if "session" in kwargs.keys():
del kwargs["session"]
result = await original_create(*args, **kwargs)
# Note: litellm calls all LLM APIs using the OpenAI format

completion_override = fetch_completion_override_from_time_travel_cache(
kwargs
)
if completion_override:
result_model = ChatCompletion.model_validate_json(completion_override)
return self._handle_response_v1_openai(
result_model, kwargs, init_timestamp, session=session
)

# prompt_override = fetch_prompt_override_from_time_travel_cache(kwargs)
# if prompt_override:
# kwargs["messages"] = prompt_override["messages"]

# Call the original function with its original arguments
result = await original_create_async(*args, **kwargs)
return self._handle_response_v1_openai(
result, kwargs, init_timestamp, session=session
)
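All four patched providers above follow the same short-circuit pattern. A condensed sketch of that shared flow, with names simplified (a paraphrase of the diff, not verbatim code):

# Shared override flow (sketch): replay a cached completion if one matches, else call the API.
def patched(*args, **kwargs):
    init_timestamp = get_ISO_time()
    session = kwargs.pop("session", None)

    override = fetch_completion_override_from_time_travel_cache(kwargs)
    if override:
        # Cache hit: rebuild the response object and skip the provider entirely.
        result_model = ChatCompletion.model_validate_json(override)
        return self._handle_response_v1_openai(result_model, kwargs, init_timestamp, session=session)

    # Cache miss (or Time Travel off): make the real call as before.
    result = original_create(*args, **kwargs)
    return self._handle_response_v1_openai(result, kwargs, init_timestamp, session=session)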
138 changes: 138 additions & 0 deletions agentops/time_travel.py
@@ -0,0 +1,138 @@
import json
import yaml
from .http_client import HttpClient
from .exceptions import ApiServerException
import os
from .helpers import singleton
from os import environ


@singleton
class TimeTravel:
def __init__(self):
self._completion_overrides_map = {}
self._prompt_override_map = {}

script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(script_dir)
cache_path = os.path.join(parent_dir, "agentops_time_travel.json")

try:
with open(cache_path, "r") as file:
time_travel_cache_json = json.load(file)
self._completion_overrides_map = time_travel_cache_json.get(
"completion_overrides", {}
)
self._prompt_override_map = time_travel_cache_json.get(
"prompt_override", {}
)
except FileNotFoundError:
return


def fetch_time_travel_id(ttd_id):
try:
endpoint = environ.get("AGENTOPS_API_ENDPOINT", "https://api.agentops.ai")
payload = json.dumps({"ttd_id": ttd_id}).encode("utf-8")
ttd_res = HttpClient.post(f"{endpoint}/v2/get_ttd", payload)
if ttd_res.code != 200:
raise Exception(
f"Failed to fetch TTD with status code {ttd_res.code}"
)

prompt_to_returns_map = {
"completion_overrides": {
(
str({"messages": item["prompt"]["messages"]})
if item["prompt"].get("type") == "chatml"
else str(item["prompt"])
): item["returns"]
for item in ttd_res.body # TODO: rename returns to completion_override
}
}
with open("agentops_time_travel.json", "w") as file:
json.dump(prompt_to_returns_map, file, indent=4)

set_time_travel_active_state(True)
except ApiServerException as e:
manage_time_travel_state(activated=False, error=e)
except Exception as e:
manage_time_travel_state(activated=False, error=e)
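The file written here keys each prompt by its stringified messages and stores the raw completion JSON to replay. An illustrative entry (values are made up for the example):

# Illustrative shape of agentops_time_travel.json; the completion JSON is a placeholder.
example_cache = {
    "completion_overrides": {
        str({"messages": [{"role": "user", "content": "Hi"}]}): '<raw ChatCompletion JSON>'
    }
}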


def fetch_completion_override_from_time_travel_cache(kwargs):
if not check_time_travel_active():
return

if TimeTravel()._completion_overrides_map:
search_prompt = str({"messages": kwargs["messages"]})
result_from_cache = TimeTravel()._completion_overrides_map.get(search_prompt)
return result_from_cache
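Note that only `messages` participates in the cache key, and matching is exact string equality, so any change to the prompt (or a different key order) misses the cache. For example:

# Example request kwargs (model name is hypothetical) -- only messages form the key:
kwargs = {"model": "gpt-4", "messages": [{"role": "user", "content": "Hi"}]}
search_prompt = str({"messages": kwargs["messages"]})  # model, temperature, etc. are ignored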


def fetch_prompt_override_from_time_travel_cache(kwargs):
if not check_time_travel_active():
return

if TimeTravel()._prompt_override_map:
search_prompt = str({"messages": kwargs["messages"]})
result_from_cache = TimeTravel()._prompt_override_map.get(search_prompt)
if result_from_cache is None:
return None
return json.loads(result_from_cache)


def check_time_travel_active():
script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(script_dir)
config_file_path = os.path.join(parent_dir, ".agentops_time_travel.yaml")

try:
with open(config_file_path, "r") as config_file:
config = yaml.safe_load(config_file) or {}
except FileNotFoundError:
# No config file means the feature has never been turned on
return False

if config.get("Time_Travel_Debugging_Active", True):
manage_time_travel_state(activated=True)
return True

return False


def set_time_travel_active_state(is_active: bool):
config_path = ".agentops_time_travel.yaml"
try:
with open(config_path, "r") as config_file:
config = yaml.safe_load(config_file) or {}
except FileNotFoundError:
config = {}

config["Time_Travel_Debugging_Active"] = is_active

with open(config_path, "w") as config_file:
try:
yaml.dump(config, config_file)
except Exception:
print(
f"🖇 AgentOps: Unable to write to {config_path}. Time Travel not activated"
)
return

if is_active:
manage_time_travel_state(activated=True)
print("AgentOps: Time Travel Activated")
else:
manage_time_travel_state(activated=False)
print("🖇 AgentOps: Time Travel Deactivated")


def add_time_travel_terminal_indicator():
print(f"🖇️ ⏰ | ", end="")


def reset_terminal():
print("\033[0m", end="")


def manage_time_travel_state(activated=False, error=None):
if activated:
add_time_travel_terminal_indicator()
else:
reset_terminal()
if error is not None:
print(f"🖇 Deactivating Time Travel. Error with configuration: {error}")