Merge pull request #224 from leeeizhang/lei/redesign-chat
[MRG] redesign `mle chat`
huangyz0918 authored Sep 27, 2024
2 parents 9cfbd76 + 728a879 commit 6f378d6
Showing 9 changed files with 199 additions and 74 deletions.
1 change: 1 addition & 0 deletions mle/agents/__init__.py
@@ -4,3 +4,4 @@
from .planner import *
from .summarizer import *
from .reporter import *
from .chat import *
129 changes: 129 additions & 0 deletions mle/agents/chat.py
@@ -0,0 +1,129 @@
import sys
import json
from rich.console import Console

from mle.function import *
from mle.utils import get_config, print_in_box, WorkflowCache


class ChatAgent:

def __init__(self, model, working_dir='.', console=None):
"""
ChatAgent assists users with planning and debugging ML projects.
Args:
model: The machine learning model used for generating responses.
"""
config_data = get_config()

self.model = model
self.chat_history = []
self.working_dir = working_dir
self.cache = WorkflowCache(working_dir, 'baseline')

self.console = console
if not self.console:
self.console = Console()

self.sys_prompt = f"""
You are a programmer working on a Machine Learning task using Python.
You are currently working on: {self.working_dir}.
You can leverage your capabilities by using the specific functions listed below:
1. Creating project structures based on the user requirement using function `create_directory`.
2. Writing clean, efficient, and well-documented code using function `create_file` and `write_file`.
3. Examining the project to re-use existing code snippets as much as possible; you may need to use
functions like `list_files`, `read_file` and `write_file`.
4. Writing the code into the file when creating new files; do not create empty files.
5. Use function `preview_csv_data` to preview the CSV data if the task includes CSV data processing.
6. Decide whether the task requires execution and debugging before moving to the next or not.
7. Generate the commands to run and test the current task, and the dependencies list for this task.
8. You only write Python scripts; don't write Jupyter notebooks, which require interactive execution.
"""
self.search_prompt = """
9. Perform web searches using function `web_search` to get up-to-date information or additional context.
"""

self.functions = [
schema_read_file,
schema_create_file,
schema_write_file,
schema_list_files,
schema_create_directory,
schema_search_arxiv,
schema_search_papers_with_code,
schema_execute_command,
schema_preview_csv_data
]

if config_data.get('search_key'):
self.functions.append(schema_web_search)
self.sys_prompt += self.search_prompt

if not self.cache.is_empty():
dataset = self.cache.resume_variable("dataset")
ml_requirement = self.cache.resume_variable("ml_requirement")
advisor_report = self.cache.resume_variable("advisor_report")
self.sys_prompt += f"""
The overall project information: \n
{'Dataset: ' + dataset if dataset else ''} \n
{'Requirement: ' + ml_requirement if ml_requirement else ''} \n
{'Advisor: ' + advisor_report if advisor_report else ''} \n
"""

self.chat_history.append({"role": 'system', "content": self.sys_prompt})

def greet(self):
"""
Generate a greeting message to the user, including inquiries about the project's purpose and
an overview of the support provided. This initializes a collaborative tone with the user.
Returns:
str: The generated greeting message.
"""
system_prompt = """
You are a Chatbot designed to collaborate with users on planning and debugging ML projects.
Your goal is to provide concise and friendly greetings within 50 words, including:
1. Inquiring about the project's purpose or objective.
2. Summarizing the previous conversation if one exists.
3. Offering a brief overview of the assistance and support you can provide to the user, such as:
- Helping with project planning and management.
- Assisting with debugging and troubleshooting code.
- Offering advice on best practices and optimization techniques.
- Providing resources and references for further learning.
Make sure your greeting is inviting and sets a positive tone for collaboration.
"""
self.chat_history.append({"role": "system", "content": system_prompt})
greets = self.model.query(
self.chat_history,
function_call='auto',
functions=self.functions,
)

self.chat_history.append({"role": "assistant", "content": greets})
return greets

def chat(self, user_prompt):
"""
Handle the streamed response from the model.
The stream mode integrates with the model's streaming function, so we don't
need to set it to JSON mode.
Args:
user_prompt: the user prompt.
"""
text = ''
self.chat_history.append({"role": "user", "content": user_prompt})
for content in self.model.stream(
self.chat_history,
function_call='auto',
functions=self.functions,
):
if content:
text += content
yield text

self.chat_history.append({"role": "assistant", "content": text})
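
For reference, a minimal usage sketch of the new ChatAgent outside the CLI workflow (the model name passed to load_model and the user prompt are illustrative assumptions only):

# Sketch only: exercising ChatAgent directly; the model name is an assumed example.
from mle.model import load_model
from mle.agents import ChatAgent

model = load_model('.', 'gpt-4o')            # assumption: any model name accepted by load_model
chatbot = ChatAgent(model, working_dir='.')

print(chatbot.greet())                       # one-off greeting, also appended to chat_history

reply = ''
for text in chatbot.chat("How should I structure a baseline training script?"):
    reply = text                             # each yield is the full accumulated reply so far
print(reply)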
32 changes: 0 additions & 32 deletions mle/agents/coder.py
@@ -205,35 +205,3 @@ def interact(self, task_dict: dict):
)
print_in_box(process_summary(self.code_summary), self.console, title="MLE Developer", color="cyan")
return self.code_summary

def chat(self, user_prompt):
"""
Handle the response from the model streaming.
The stream mode is integrative with the model streaming function, we don't
need to set it into the JSON mode.
Args:
user_prompt: the user prompt.
"""
text = ''
self.chat_history.append({"role": "user", "content": user_prompt})
for content in self.model.stream(
self.chat_history,
function_call='auto',
functions=[
schema_read_file,
schema_create_file,
schema_write_file,
schema_list_files,
schema_create_directory,
schema_search_arxiv,
schema_search_papers_with_code,
schema_web_search,
schema_execute_command,
schema_preview_csv_data
]
):
if content:
text += content
yield text

self.chat_history.append({"role": "assistant", "content": text})
42 changes: 10 additions & 32 deletions mle/cli.py
Expand Up @@ -6,18 +6,13 @@
import uvicorn
import questionary
from pathlib import Path
from rich.live import Live
from rich.panel import Panel
from rich.console import Console
from rich.markdown import Markdown
from concurrent.futures import ThreadPoolExecutor

import mle
import mle.workflow as workflow
from mle.server import app
from mle.model import load_model
from mle.agents import CodeAgent
import mle.workflow as workflow
from mle.utils import Memory, WorkflowCache
from mle.utils import Memory
from mle.utils.system import (
get_config,
write_config,
@@ -58,6 +53,9 @@ def start(ctx, mode, model):
elif mode == 'kaggle':
# Kaggle mode
return ctx.invoke(kaggle, model=model)
elif mode == 'chat':
# Chat mode
return ctx.invoke(chat, model=model)
else:
raise ValueError("Invalid mode. Supported modes: 'baseline', 'report', 'kaggle'.")

@@ -79,6 +77,8 @@ def report(ctx, repo, model, user, visualize):
"[blue underline]http://localhost:3000/[/blue underline]",
console=console, title="MLE Report", color="green"
)
from concurrent.futures import ThreadPoolExecutor

with ThreadPoolExecutor() as executor:
future1 = executor.submit(ctx.invoke, serve)
future2 = executor.submit(ctx.invoke, web)
@@ -139,37 +139,15 @@ def kaggle(model):


@cli.command()
def chat():
@click.option('--model', default=None, help='The model to use for the chat.')
def chat(model):
"""
chat: start an interactive chat with the LLM to work on your ML project.
"""
if not check_config(console):
return

model = load_model(os.getcwd())
cache = WorkflowCache(os.getcwd())
coder = CodeAgent(model)

# read the project information
dataset = cache.resume_variable("dataset")
ml_requirement = cache.resume_variable("ml_requirement")
advisor_report = cache.resume_variable("advisor_report")

# inject the project information into prompts
coder.read_requirement(advisor_report or ml_requirement or dataset)

while True:
try:
user_pmpt = questionary.text("[Exit/Ctrl+D]: ").ask()
if user_pmpt:
with Live(console=Console()) as live:
for text in coder.chat(user_pmpt.strip()):
live.update(
Panel(Markdown(text), title="[bold magenta]MLE-Agent[/]", border_style="magenta"),
refresh=True
)
except (KeyboardInterrupt, EOFError):
exit()
return workflow.chat(os.getcwd(), model)


@cli.command()
21 changes: 14 additions & 7 deletions mle/utils/cache.py
@@ -71,16 +71,18 @@ class WorkflowCache:
methods to load, store, and remove cached steps.
"""

def __init__(self, project_dir: str):
def __init__(self, project_dir: str, workflow: str = 'baseline'):
"""
Initialize WorkflowCache with a project directory.
Args:
project_dir (str): The directory of the project.
workflow (str): The name of the cached workflow.
"""
self.project_dir = project_dir
self.buffer = self._load_cache_buffer()
self.cache: Dict[int, Dict[str, Any]] = self.buffer["cache"]
self.workflow = workflow
self.buffer = self._load_cache_buffer(workflow)
self.cache: Dict[int, Dict[str, Any]] = self.buffer["cache"][workflow]

def is_empty(self) -> bool:
"""
@@ -124,22 +126,27 @@ def resume_variable(self, key: str, step: Optional[int] = None):
if step is not None:
return self.__call__(step).resume(key)
else:
for step in range(self.current_step()):
for step in range(self.current_step() + 1):
value = self.resume_variable(key, step)
if value is not None:
return value
return None

def _load_cache_buffer(self) -> Dict[str, Any]:
def _load_cache_buffer(self, workflow: str) -> Dict[str, Any]:
"""
Load the cache buffer from the configuration.
Args:
workflow (str): The name of the cached workflow.
Returns:
dict: The buffer loaded from the configuration.
"""
buffer = get_config() or {}
if "cache" not in buffer:
if "cache" not in buffer.keys():
buffer["cache"] = {}
if workflow not in buffer["cache"].keys():
buffer["cache"][workflow] = {}
return buffer

def _store_cache_buffer(self) -> None:
@@ -159,7 +166,7 @@ def __call__(self, step: int, name: Optional[str] = None) -> WorkflowCacheOperator:
Returns:
WorkflowCacheOperator: An instance of WorkflowCacheOperator.
"""
if step not in self.cache:
if step not in self.cache.keys():
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.cache[step] = {
"step": step,
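
To illustrate the per-workflow namespace that WorkflowCache now takes, a minimal sketch (the stored key and value are illustrative; the same pattern appears in mle/workflow/chat.py below):

# Sketch only: per-workflow cache namespacing; key/value names are illustrative.
from mle.utils import WorkflowCache

cache = WorkflowCache('.', 'chat')           # entries live under buffer["cache"]["chat"]

if cache.is_empty():
    with cache(step=1, name="chat") as op:   # creates/loads step 1 inside the 'chat' workflow
        op.store("conversation", [{"role": "system", "content": "..."}])

# A later run with the same workflow name can resume the value from any cached step.
history = cache.resume_variable("conversation")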
3 changes: 2 additions & 1 deletion mle/workflow/__init__.py
@@ -1,3 +1,4 @@
from .baseline import baseline
from .report import report
from .kaggle import kaggle
from .kaggle import kaggle
from .chat import chat
2 changes: 1 addition & 1 deletion mle/workflow/baseline.py
@@ -29,7 +29,7 @@ def baseline(work_dir: str, model=None):
"""

console = Console()
cache = WorkflowCache(work_dir)
cache = WorkflowCache(work_dir, 'baseline')
model = load_model(work_dir, model)

if not cache.is_empty():
41 changes: 41 additions & 0 deletions mle/workflow/chat.py
@@ -0,0 +1,41 @@
"""
Chat Mode: the mode to have an interactive chat with the LLM to work on your ML project.
"""
import os
import questionary
from rich.live import Live
from rich.panel import Panel
from rich.console import Console
from rich.markdown import Markdown
from mle.model import load_model
from mle.utils import print_in_box, WorkflowCache
from mle.agents import ChatAgent


def chat(work_dir: str, model=None):
console = Console()
cache = WorkflowCache(work_dir, 'chat')
model = load_model(work_dir, model)
chatbot = ChatAgent(model)

if not cache.is_empty():
if questionary.confirm(f"Would you like to continue the previous conversation?\n").ask():
chatbot.chat_history = cache.resume_variable("conversation")

with cache(step=1, name="chat") as ca:
greets = chatbot.greet()
print_in_box(greets, console=console, title="MLE Chatbot", color="magenta")

while True:
try:
user_pmpt = questionary.text("[Exit/Ctrl+D]: ").ask()
if user_pmpt:
with Live(console=Console()) as live:
for text in chatbot.chat(user_pmpt.strip()):
live.update(
Panel(Markdown(text), title="[bold magenta]MLE-Agent[/]", border_style="magenta"),
refresh=True
)
ca.store("conversation", chatbot.chat_history)
except (KeyboardInterrupt, EOFError):
break
2 changes: 1 addition & 1 deletion mle/workflow/kaggle.py
@@ -15,7 +15,7 @@ def kaggle(work_dir: str, model=None, kaggle_username=None, kaggle_token=None):
The workflow of the kaggle mode.
"""
console = Console()
cache = WorkflowCache(work_dir)
cache = WorkflowCache(work_dir, 'kaggle')
model = load_model(work_dir, model)
kaggle = KaggleIntegration(kaggle_username, kaggle_token)

