From 8c971801366ee5cdd412f0df86764a5ffb0a9627 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Sat, 23 Dec 2023 01:27:17 +0800 Subject: [PATCH 1/5] feat: New LLMClient --- dbgpt/core/__init__.py | 15 +- dbgpt/core/awel/operator/base.py | 4 +- dbgpt/core/awel/trigger/http_trigger.py | 3 +- dbgpt/core/interface/llm.py | 230 +++++++++----- dbgpt/model/__init__.py | 4 + dbgpt/model/cluster/base.py | 9 + dbgpt/model/cluster/client.py | 40 +++ dbgpt/model/cluster/manager_base.py | 40 ++- dbgpt/model/cluster/tests/conftest.py | 10 +- dbgpt/model/cluster/worker/default_worker.py | 39 ++- .../model/cluster/worker/embedding_worker.py | 9 + dbgpt/model/cluster/worker/manager.py | 82 ++++- dbgpt/model/cluster/worker/remote_manager.py | 47 ++- dbgpt/model/cluster/worker/remote_worker.py | 40 ++- dbgpt/model/cluster/worker_base.py | 38 ++- dbgpt/model/proxy/llms/chatgpt.py | 2 +- dbgpt/model/utils/__init__.py | 0 dbgpt/model/utils/chatgpt_utils.py | 282 ++++++++++++++++++ examples/awel/simple_llm_client_example.py | 158 ++++++++++ examples/sdk/simple_sdk_llm_example.py | 10 +- examples/sdk/simple_sdk_llm_sql_example.py | 10 +- 21 files changed, 964 insertions(+), 108 deletions(-) create mode 100644 dbgpt/model/cluster/client.py create mode 100644 dbgpt/model/utils/__init__.py create mode 100644 dbgpt/model/utils/chatgpt_utils.py create mode 100644 examples/awel/simple_llm_client_example.py diff --git a/dbgpt/core/__init__.py b/dbgpt/core/__init__.py index c9996d3cc..e740b1294 100644 --- a/dbgpt/core/__init__.py +++ b/dbgpt/core/__init__.py @@ -1,9 +1,12 @@ from dbgpt.core.interface.llm import ( ModelInferenceMetrics, + ModelRequest, ModelOutput, - OpenAILLM, - BaseLLMOperator, + LLMClient, + LLMOperator, + StreamingLLMOperator, RequestBuildOperator, + ModelMetadata, ) from dbgpt.core.interface.message import ( ModelMessage, @@ -37,11 +40,15 @@ __ALL__ = [ "ModelInferenceMetrics", + "ModelRequest", "ModelOutput", - "OpenAILLM", - "BaseLLMOperator", + "Operator", "RequestBuildOperator", + "ModelMetadata", "ModelMessage", + "LLMClient", + "LLMOperator", + "StreamingLLMOperator", "ModelMessageRoleType", "OnceConversation", "StorageConversation", diff --git a/dbgpt/core/awel/operator/base.py b/dbgpt/core/awel/operator/base.py index 2e7f06bcf..c114412b0 100644 --- a/dbgpt/core/awel/operator/base.py +++ b/dbgpt/core/awel/operator/base.py @@ -211,7 +211,9 @@ async def call_stream( Returns: AsyncIterator[OUT]: An asynchronous iterator over the output stream. """ - out_ctx = await self._runner.execute_workflow(self, call_data) + out_ctx = await self._runner.execute_workflow( + self, call_data, streaming_call=True + ) return out_ctx.current_task_context.task_output.output_stream def _blocking_call_stream( diff --git a/dbgpt/core/awel/trigger/http_trigger.py b/dbgpt/core/awel/trigger/http_trigger.py index 5fb2228d2..58f0ba529 100644 --- a/dbgpt/core/awel/trigger/http_trigger.py +++ b/dbgpt/core/awel/trigger/http_trigger.py @@ -130,8 +130,9 @@ async def _trigger_dag( "Connection": "keep-alive", "Transfer-Encoding": "chunked", } + generator = await end_node.call_stream(call_data={"data": body}) return StreamingResponse( - end_node.call_stream(call_data={"data": body}), + generator, headers=headers, media_type=media_type, ) diff --git a/dbgpt/core/interface/llm.py b/dbgpt/core/interface/llm.py index c99b0c8db..49fe84aa7 100644 --- a/dbgpt/core/interface/llm.py +++ b/dbgpt/core/interface/llm.py @@ -1,10 +1,10 @@ -from abc import ABC +from abc import ABC, abstractmethod from typing import Optional, Dict, List, Any, Union, AsyncIterator - import time -from dataclasses import dataclass, asdict +from dataclasses import dataclass, asdict, field import copy +from dbgpt.util import BaseParameters from dbgpt.util.annotations import PublicAPI from dbgpt.util.model_utils import GPUInfo from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType @@ -12,6 +12,7 @@ @dataclass +@PublicAPI(stability="beta") class ModelInferenceMetrics: """A class to represent metrics for assessing the inference performance of a LLM.""" @@ -97,6 +98,7 @@ def to_dict(self) -> Dict: @dataclass +@PublicAPI(stability="beta") class ModelOutput: """A class to represent the output of a LLM.""" "" @@ -118,6 +120,7 @@ def to_dict(self) -> Dict: @dataclass +@PublicAPI(stability="beta") class ModelRequest: model: str """The name of the model.""" @@ -142,7 +145,7 @@ class ModelRequest: span_id: Optional[str] = None """The span id of the model inference.""" - def to_dict(self) -> Dict: + def to_dict(self) -> Dict[str, Any]: new_reqeust = copy.deepcopy(self) new_reqeust.messages = list( map(lambda m: m if isinstance(m, dict) else m.dict(), new_reqeust.messages) @@ -166,6 +169,110 @@ def _build(model: str, prompt: str, **kwargs): **kwargs, ) + def to_openai_messages(self) -> List[Dict[str, Any]]: + """Convert the messages to the format of OpenAI API. + + This function will move last user message to the end of the list. + + Returns: + List[Dict[str, Any]]: The messages in the format of OpenAI API. + + Examples: + .. code-block:: python + messages = [ + ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"), + ModelMessage(role=ModelMessageRoleType.AI, content="Hi, I'm a robot.") + ModelMessage(role=ModelMessageRoleType.HUMAN, content="Who are your"), + ] + openai_messages = ModelRequest.to_openai_messages(messages) + assert openai_messages == [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hi, I'm a robot."}, + {"role": "user", "content": "Who are your"}, + ] + """ + messages = [ + m if isinstance(m, ModelMessage) else ModelMessage(**m) + for m in self.messages + ] + return ModelMessage.to_openai_messages(messages) + + +@dataclass +@PublicAPI(stability="beta") +class ModelMetadata(BaseParameters): + """A class to represent a LLM model.""" + + model: str = field( + metadata={"help": "Model name"}, + ) + context_length: Optional[int] = field( + default=4096, + metadata={"help": "Context length of model"}, + ) + chat_model: Optional[bool] = field( + default=True, + metadata={"help": "Whether the model is a chat model"}, + ) + is_function_calling_model: Optional[bool] = field( + default=False, + metadata={"help": "Whether the model is a function calling model"}, + ) + metadata: Optional[Dict[str, Any]] = field( + default_factory=dict, + metadata={"help": "Model metadata"}, + ) + + +@PublicAPI(stability="beta") +class LLMClient(ABC): + """An abstract class for LLM client.""" + + @abstractmethod + async def generate(self, request: ModelRequest) -> ModelOutput: + """Generate a response for a given model request. + + Args: + request(ModelRequest): The model request. + + Returns: + ModelOutput: The model output. + + """ + + @abstractmethod + async def generate_stream( + self, request: ModelRequest + ) -> AsyncIterator[ModelOutput]: + """Generate a stream of responses for a given model request. + + Args: + request(ModelRequest): The model request. + + Returns: + AsyncIterator[ModelOutput]: The model output stream. + """ + + @abstractmethod + async def models(self) -> List[ModelMetadata]: + """Get all the models. + + Returns: + List[ModelMetadata]: A list of model metadata. + """ + + @abstractmethod + async def count_token(self, model: str, prompt: str) -> int: + """Count the number of tokens in a given prompt. + + Args: + model(str): The model name. + prompt(str): The prompt. + + Returns: + int: The number of tokens. + """ + class RequestBuildOperator(MapOperator[str, ModelRequest], ABC): def __init__(self, model: str, **kwargs): @@ -176,85 +283,52 @@ async def map(self, input_value: str) -> ModelRequest: return ModelRequest._build(self._model, input_value) -class BaseLLMOperator( - MapOperator[ModelRequest, ModelOutput], - StreamifyAbsOperator[ModelRequest, ModelOutput], - ABC, -): +class BaseLLM: """The abstract operator for a LLM.""" + def __init__(self, llm_client: Optional[LLMClient] = None): + self._llm_client = llm_client -@PublicAPI(stability="beta") -class OpenAILLM(BaseLLMOperator): - """The operator for OpenAI LLM. + @property + def llm_client(self) -> LLMClient: + """Return the LLM client.""" + if not self._llm_client: + raise ValueError("llm_client is not set") + return self._llm_client + + +class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC): + """The operator for a LLM. - Examples: + Args: + llm_client (LLMClient, optional): The LLM client. Defaults to None. - .. code-block:: python - llm = OpenAILLM() - model_request = ModelRequest(model="gpt-3.5-turbo", messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")]) - model_output = await llm.map(model_request) + This operator will generate a no streaming response. """ - def __int__(self): - try: - import openai - except ImportError as e: - raise ImportError("Please install openai package to use OpenAILLM") from e - import importlib.metadata as metadata - - if not metadata.version("openai") >= "1.0.0": - raise ImportError("Please upgrade openai package to version 1.0.0 or above") - - async def _send_request( - self, model_request: ModelRequest, stream: Optional[bool] = False - ): - import os - from openai import AsyncOpenAI - - client = AsyncOpenAI( - api_key=os.environ.get("OPENAI_API_KEY"), - base_url=os.environ.get("OPENAI_API_BASE"), - ) - messages = ModelMessage.to_openai_messages(model_request._get_messages()) - payloads = { - "model": model_request.model, - "stream": stream, - } - if model_request.temperature is not None: - payloads["temperature"] = model_request.temperature - if model_request.max_new_tokens: - payloads["max_tokens"] = model_request.max_new_tokens - - return await client.chat.completions.create(messages=messages, **payloads) - - async def map(self, model_request: ModelRequest) -> ModelOutput: - try: - chat_completion = await self._send_request(model_request, stream=False) - text = chat_completion.choices[0].message.content - usage = chat_completion.usage.dict() - return ModelOutput(text=text, error_code=0, usage=usage) - except Exception as e: - return ModelOutput( - text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", - error_code=1, - ) + def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): + super().__init__(llm_client=llm_client) + MapOperator.__init__(self, **kwargs) - async def streamify( - self, model_request: ModelRequest - ) -> AsyncIterator[ModelOutput]: - try: - chat_completion = await self._send_request(model_request, stream=True) - text = "" - for r in chat_completion: - if len(r.choices) == 0: - continue - if r.choices[0].delta.content is not None: - content = r.choices[0].delta.content - text += content - yield ModelOutput(text=text, error_code=0) - except Exception as e: - yield ModelOutput( - text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", - error_code=1, - ) + async def map(self, request: ModelRequest) -> ModelOutput: + return await self.llm_client.generate(request) + + +class StreamingLLMOperator( + BaseLLM, StreamifyAbsOperator[ModelRequest, ModelOutput], ABC +): + """The streaming operator for a LLM. + + Args: + llm_client (LLMClient, optional): The LLM client. Defaults to None. + + This operator will generate streaming response. + """ + + def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): + super().__init__(llm_client=llm_client) + StreamifyAbsOperator.__init__(self, **kwargs) + + async def streamify(self, request: ModelRequest) -> AsyncIterator[ModelOutput]: + async for output in self.llm_client.generate_stream(request): + yield output diff --git a/dbgpt/model/__init__.py b/dbgpt/model/__init__.py index e69de29bb..4e317eb06 100644 --- a/dbgpt/model/__init__.py +++ b/dbgpt/model/__init__.py @@ -0,0 +1,4 @@ +from dbgpt.model.cluster.client import DefaultLLMClient +from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient + +__ALL__ = ["DefaultLLMClient", "OpenAILLMClient"] diff --git a/dbgpt/model/cluster/base.py b/dbgpt/model/cluster/base.py index 954f5ef97..cb3f34732 100644 --- a/dbgpt/model/cluster/base.py +++ b/dbgpt/model/cluster/base.py @@ -30,6 +30,15 @@ class EmbeddingsRequest(BaseModel): span_id: str = None +class CountTokenRequest(BaseModel): + model: str + prompt: str + + +class ModelMetadataRequest(BaseModel): + model: str + + class WorkerApplyRequest(BaseModel): model: str apply_type: WorkerApplyType diff --git a/dbgpt/model/cluster/client.py b/dbgpt/model/cluster/client.py new file mode 100644 index 000000000..10b1cfb7d --- /dev/null +++ b/dbgpt/model/cluster/client.py @@ -0,0 +1,40 @@ +from typing import AsyncIterator, List +import asyncio +from dbgpt.core.interface.llm import LLMClient, ModelRequest, ModelOutput, ModelMetadata +from dbgpt.model.parameter import WorkerType +from dbgpt.model.cluster.manager_base import WorkerManager + + +class DefaultLLMClient(LLMClient): + def __init__(self, worker_manager: WorkerManager): + self._worker_manager = worker_manager + + async def generate(self, request: ModelRequest) -> ModelOutput: + return await self._worker_manager.generate(request.to_dict()) + + async def generate_stream( + self, request: ModelRequest + ) -> AsyncIterator[ModelOutput]: + async for output in self._worker_manager.generate_stream(request.to_dict()): + yield output + + async def models(self) -> List[ModelMetadata]: + instances = await self._worker_manager.get_all_model_instances( + WorkerType.LLM.value, healthy_only=True + ) + query_metadata_task = [] + for instance in instances: + worker_name, _ = WorkerType.parse_worker_key(instance.worker_key) + query_metadata_task.append( + self._worker_manager.get_model_metadata({"model": worker_name}) + ) + models: List[ModelMetadata] = await asyncio.gather(*query_metadata_task) + model_map = {} + for single_model in models: + model_map[single_model.model] = single_model + return [model_map[model_name] for model_name in sorted(model_map.keys())] + + async def count_token(self, model: str, prompt: str) -> int: + return await self._worker_manager.count_token( + {"model": model, "prompt": prompt} + ) diff --git a/dbgpt/model/cluster/manager_base.py b/dbgpt/model/cluster/manager_base.py index c4beb8ce7..636b2822f 100644 --- a/dbgpt/model/cluster/manager_base.py +++ b/dbgpt/model/cluster/manager_base.py @@ -5,7 +5,7 @@ from datetime import datetime from concurrent.futures import Future from dbgpt.component import BaseComponent, ComponentType, SystemApp -from dbgpt.core import ModelOutput +from dbgpt.core import ModelOutput, ModelMetadata from dbgpt.model.base import WorkerSupportedModel, WorkerApplyOutput from dbgpt.model.cluster.worker_base import ModelWorker from dbgpt.model.cluster.base import WorkerStartupRequest, WorkerApplyRequest @@ -38,6 +38,11 @@ def _to_print_key(self): port = self.port return f"model {model_name}@{model_type}({host}:{port})" + @property + def stopped(self): + """Check if the worker is stopped""" "" + return self.stop_event.is_set() + class WorkerManager(ABC): @abstractmethod @@ -62,6 +67,20 @@ async def get_model_instances( ) -> List[WorkerRunData]: """Asynchronous get model instances by worker type and model name""" + @abstractmethod + async def get_all_model_instances( + self, worker_type: str, healthy_only: bool = True + ) -> List[WorkerRunData]: + """Asynchronous get all model instances + + Args: + worker_type (str): worker type + healthy_only (bool, optional): only return healthy instances. Defaults to True. + + Returns: + List[WorkerRunData]: worker run data list + """ + @abstractmethod def sync_get_model_instances( self, worker_type: str, model_name: str, healthy_only: bool = True @@ -112,6 +131,25 @@ def sync_embeddings(self, params: Dict) -> List[List[float]]: We must provide a synchronous version. """ + @abstractmethod + async def count_token(self, params: Dict) -> int: + """Count token of prompt + + Args: + params (Dict): parameters, eg. {"prompt": "hello", "model": "vicuna-13b-v1.5"} + + Returns: + int: token count + """ + + @abstractmethod + async def get_model_metadata(self, params: Dict) -> ModelMetadata: + """Get model metadata + + Args: + params (Dict): parameters, eg. {"model": "vicuna-13b-v1.5"} + """ + @abstractmethod async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput: """Worker apply""" diff --git a/dbgpt/model/cluster/tests/conftest.py b/dbgpt/model/cluster/tests/conftest.py index 49cdbc0a1..eeba4826f 100644 --- a/dbgpt/model/cluster/tests/conftest.py +++ b/dbgpt/model/cluster/tests/conftest.py @@ -3,7 +3,7 @@ from contextlib import contextmanager, asynccontextmanager from typing import List, Iterator, Dict, Tuple from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType -from dbgpt.core import ModelOutput +from dbgpt.core import ModelOutput, ModelMetadata from dbgpt.model.cluster.worker_base import ModelWorker from dbgpt.model.cluster.worker.manager import ( WorkerManager, @@ -80,6 +80,14 @@ def generate(self, params: Dict) -> ModelOutput: output = out return output + def count_token(self, prompt: str) -> int: + return len(prompt) + + def get_model_metadata(self, params: Dict) -> ModelMetadata: + return ModelMetadata( + model=self.model_parameters.model_name, + ) + def embeddings(self, params: Dict) -> List[List[float]]: return self._embeddings diff --git a/dbgpt/model/cluster/worker/default_worker.py b/dbgpt/model/cluster/worker/default_worker.py index c4967a076..a42858bb4 100644 --- a/dbgpt/model/cluster/worker/default_worker.py +++ b/dbgpt/model/cluster/worker/default_worker.py @@ -8,7 +8,7 @@ from dbgpt.configs.model_config import get_device from dbgpt.model.adapter.base import LLMModelAdapter from dbgpt.model.adapter.model_adapter import get_llm_model_adapter -from dbgpt.core import ModelOutput, ModelInferenceMetrics +from dbgpt.core import ModelOutput, ModelInferenceMetrics, ModelMetadata from dbgpt.model.loader import ModelLoader, _get_model_real_path from dbgpt.model.parameter import ModelParameters from dbgpt.model.cluster.worker_base import ModelWorker @@ -118,6 +118,8 @@ def start( f"Parse model max length {model_max_length} from model {self.model_name}." ) self.context_len = model_max_length + elif hasattr(model_params, "max_context_size"): + self.context_len = model_params.max_context_size def stop(self) -> None: if not self.model: @@ -186,6 +188,22 @@ def generate(self, params: Dict) -> ModelOutput: output = out return output + def count_token(self, prompt: str) -> int: + return _try_to_count_token(prompt, self.tokenizer) + + async def async_count_token(self, prompt: str) -> int: + # TODO if we deploy the model by vllm, it can't work, we should run transformer _try_to_count_token to async + raise NotImplementedError + + def get_model_metadata(self, params: Dict) -> ModelMetadata: + return ModelMetadata( + model=self.model_name, + context_length=self.context_len, + ) + + async def async_get_model_metadata(self, params: Dict) -> ModelMetadata: + return self.get_model_metadata(params) + def embeddings(self, params: Dict) -> List[List[float]]: raise NotImplementedError @@ -436,6 +454,25 @@ def _new_metrics_from_model_output( return metrics +def _try_to_count_token(prompt: str, tokenizer) -> int: + """Try to count token of prompt + + Args: + prompt (str): prompt + tokenizer ([type]): tokenizer + + Returns: + int: token count, if error return -1 + + TODO: More implementation + """ + try: + return len(tokenizer(prompt).input_ids[0]) + except Exception as e: + logger.warning(f"Count token error, detail: {e}, return -1") + return -1 + + def _try_import_torch(): global torch global _torch_imported diff --git a/dbgpt/model/cluster/worker/embedding_worker.py b/dbgpt/model/cluster/worker/embedding_worker.py index f91a7a260..c8df5d779 100644 --- a/dbgpt/model/cluster/worker/embedding_worker.py +++ b/dbgpt/model/cluster/worker/embedding_worker.py @@ -2,6 +2,7 @@ from typing import Dict, List, Type, Optional from dbgpt.configs.model_config import get_device +from dbgpt.core import ModelMetadata from dbgpt.model.loader import _get_model_real_path from dbgpt.model.parameter import ( EmbeddingModelParameters, @@ -89,6 +90,14 @@ def generate(self, params: Dict): """Generate non stream result""" raise NotImplementedError("Not supported generate for embeddings model") + def count_token(self, prompt: str) -> int: + raise NotImplementedError("Not supported count_token for embeddings model") + + def get_model_metadata(self, params: Dict) -> ModelMetadata: + raise NotImplementedError( + "Not supported get_model_metadata for embeddings model" + ) + def embeddings(self, params: Dict) -> List[List[float]]: model = params.get("model") logger.info(f"Receive embeddings request, model: {model}") diff --git a/dbgpt/model/cluster/worker/manager.py b/dbgpt/model/cluster/worker/manager.py index d341824b9..83b39370e 100644 --- a/dbgpt/model/cluster/worker/manager.py +++ b/dbgpt/model/cluster/worker/manager.py @@ -15,7 +15,7 @@ from dbgpt.component import SystemApp from dbgpt.configs.model_config import LOGDIR -from dbgpt.core import ModelOutput +from dbgpt.core import ModelOutput, ModelMetadata from dbgpt.model.base import ( ModelInstance, WorkerApplyOutput, @@ -271,6 +271,18 @@ async def get_model_instances( ) -> List[WorkerRunData]: return self.sync_get_model_instances(worker_type, model_name, healthy_only) + async def get_all_model_instances( + self, worker_type: str, healthy_only: bool = True + ) -> List[WorkerRunData]: + instances = list(itertools.chain(*self.workers.values())) + result = [] + for instance in instances: + name, wt = WorkerType.parse_worker_key(instance.worker_key) + if wt != worker_type or (healthy_only and instance.stopped): + continue + result.append(instance) + return result + def sync_get_model_instances( self, worker_type: str, model_name: str, healthy_only: bool = True ) -> List[WorkerRunData]: @@ -390,6 +402,43 @@ def sync_embeddings(self, params: Dict) -> List[List[float]]: worker_run_data = self._sync_get_model(params, worker_type="text2vec") return worker_run_data.worker.embeddings(params) + async def count_token(self, params: Dict) -> int: + """Count token of prompt""" + with root_tracer.start_span( + "WorkerManager.count_token", params.get("span_id") + ) as span: + params["span_id"] = span.span_id + try: + worker_run_data = await self._get_model(params) + except Exception as e: + raise e + prompt = params.get("prompt") + async with worker_run_data.semaphore: + if worker_run_data.worker.support_async(): + return await worker_run_data.worker.async_count_token(prompt) + else: + return await self.run_blocking_func( + worker_run_data.worker.count_token, prompt + ) + + async def get_model_metadata(self, params: Dict) -> ModelMetadata: + """Get model metadata""" + with root_tracer.start_span( + "WorkerManager.get_model_metadata", params.get("span_id") + ) as span: + params["span_id"] = span.span_id + try: + worker_run_data = await self._get_model(params) + except Exception as e: + raise e + async with worker_run_data.semaphore: + if worker_run_data.worker.support_async(): + return await worker_run_data.worker.async_get_model_metadata(params) + else: + return await self.run_blocking_func( + worker_run_data.worker.get_model_metadata, params + ) + async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput: apply_func: Callable[[WorkerApplyRequest], Awaitable[str]] = None if apply_req.apply_type == WorkerApplyType.START: @@ -601,6 +650,13 @@ async def get_model_instances( worker_type, model_name, healthy_only ) + async def get_all_model_instances( + self, worker_type: str, healthy_only: bool = True + ) -> List[WorkerRunData]: + return await self.worker_manager.get_all_model_instances( + worker_type, healthy_only + ) + def sync_get_model_instances( self, worker_type: str, model_name: str, healthy_only: bool = True ) -> List[WorkerRunData]: @@ -635,6 +691,12 @@ async def embeddings(self, params: Dict) -> List[List[float]]: def sync_embeddings(self, params: Dict) -> List[List[float]]: return self.worker_manager.sync_embeddings(params) + async def count_token(self, params: Dict) -> int: + return await self.worker_manager.count_token(params) + + async def get_model_metadata(self, params: Dict) -> ModelMetadata: + return await self.worker_manager.get_model_metadata(params) + async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput: return await self.worker_manager.worker_apply(apply_req) @@ -696,6 +758,24 @@ async def api_embeddings(request: EmbeddingsRequest): return await worker_manager.embeddings(params) +@router.post("/worker/count_token") +async def api_count_token(request: CountTokenRequest): + params = request.dict(exclude_none=True) + span_id = root_tracer.get_current_span_id() + if "span_id" not in params and span_id: + params["span_id"] = span_id + return await worker_manager.count_token(params) + + +@router.post("/worker/model_metadata") +async def api_get_model_metadata(request: ModelMetadataRequest): + params = request.dict(exclude_none=True) + span_id = root_tracer.get_current_span_id() + if "span_id" not in params and span_id: + params["span_id"] = span_id + return await worker_manager.get_model_metadata(params) + + @router.post("/worker/apply") async def api_worker_apply(request: WorkerApplyRequest): return await worker_manager.worker_apply(request) diff --git a/dbgpt/model/cluster/worker/remote_manager.py b/dbgpt/model/cluster/worker/remote_manager.py index 5561c026d..aa47d572d 100644 --- a/dbgpt/model/cluster/worker/remote_manager.py +++ b/dbgpt/model/cluster/worker/remote_manager.py @@ -133,22 +133,29 @@ def _build_worker_instances( self, model_name: str, instances: List[ModelInstance] ) -> List[WorkerRunData]: worker_instances = [] - for ins in instances: - worker = RemoteModelWorker() - worker.load_worker(model_name, model_name, host=ins.host, port=ins.port) - wr = WorkerRunData( - host=ins.host, - port=ins.port, - worker_key=ins.model_name, - worker=worker, - worker_params=None, - model_params=None, - stop_event=asyncio.Event(), - semaphore=asyncio.Semaphore(100), # Not limit in client + for instance in instances: + worker_instances.append( + self._build_single_worker_instance(model_name, instance) ) - worker_instances.append(wr) return worker_instances + def _build_single_worker_instance(self, model_name: str, instance: ModelInstance): + worker = RemoteModelWorker() + worker.load_worker( + model_name, model_name, host=instance.host, port=instance.port + ) + wr = WorkerRunData( + host=instance.host, + port=instance.port, + worker_key=instance.model_name, + worker=worker, + worker_params=None, + model_params=None, + stop_event=asyncio.Event(), + semaphore=asyncio.Semaphore(100), # Not limit in client + ) + return wr + async def get_model_instances( self, worker_type: str, model_name: str, healthy_only: bool = True ) -> List[WorkerRunData]: @@ -158,6 +165,20 @@ async def get_model_instances( ) return self._build_worker_instances(model_name, instances) + async def get_all_model_instances( + self, worker_type: str, healthy_only: bool = True + ) -> List[WorkerRunData]: + instances: List[ + ModelInstance + ] = await self.model_registry.get_all_model_instances(healthy_only=healthy_only) + result = [] + for instance in instances: + name, wt = WorkerType.parse_worker_key(instance.model_name) + if wt != worker_type: + continue + result.append(self._build_single_worker_instance(name, instance)) + return result + def sync_get_model_instances( self, worker_type: str, model_name: str, healthy_only: bool = True ) -> List[WorkerRunData]: diff --git a/dbgpt/model/cluster/worker/remote_worker.py b/dbgpt/model/cluster/worker/remote_worker.py index ccf153d02..895d998d3 100644 --- a/dbgpt/model/cluster/worker/remote_worker.py +++ b/dbgpt/model/cluster/worker/remote_worker.py @@ -1,7 +1,7 @@ import json from typing import Dict, Iterator, List import logging -from dbgpt.core import ModelOutput +from dbgpt.core import ModelOutput, ModelMetadata from dbgpt.model.parameter import ModelParameters from dbgpt.model.cluster.worker_base import ModelWorker @@ -90,6 +90,44 @@ async def async_generate(self, params: Dict) -> ModelOutput: ) return ModelOutput(**response.json()) + def count_token(self, prompt: str) -> int: + raise NotImplementedError + + async def async_count_token(self, prompt: str) -> int: + import httpx + + async with httpx.AsyncClient() as client: + url = self.worker_addr + "/count_token" + logger.debug(f"Send async_count_token to url {url}, params: {prompt}") + response = await client.post( + url, + headers=self.headers, + json={"prompt": prompt}, + timeout=self.timeout, + ) + return response.json() + + async def async_get_model_metadata(self, params: Dict) -> ModelMetadata: + """Asynchronously get model metadata""" + import httpx + + async with httpx.AsyncClient() as client: + url = self.worker_addr + "/model_metadata" + logger.debug( + f"Send async_get_model_metadata to url {url}, params: {params}" + ) + response = await client.post( + url, + headers=self.headers, + json=params, + timeout=self.timeout, + ) + return ModelMetadata(**response.json()) + + def get_model_metadata(self, params: Dict) -> ModelMetadata: + """Get model metadata""" + raise NotImplementedError + def embeddings(self, params: Dict) -> List[List[float]]: """Get embeddings for input""" import requests diff --git a/dbgpt/model/cluster/worker_base.py b/dbgpt/model/cluster/worker_base.py index b89edfa15..1f0005a8f 100644 --- a/dbgpt/model/cluster/worker_base.py +++ b/dbgpt/model/cluster/worker_base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Dict, Iterator, List, Type -from dbgpt.core import ModelOutput +from dbgpt.core import ModelOutput, ModelMetadata from dbgpt.model.parameter import ModelParameters, WorkerType from dbgpt.util.parameter_utils import ( ParameterDescription, @@ -92,6 +92,42 @@ async def async_generate(self, params: Dict) -> ModelOutput: """Asynchronously generate output (non-stream) based on provided parameters.""" raise NotImplementedError + @abstractmethod + def count_token(self, prompt: str) -> int: + """Count token of prompt + Args: + prompt (str): prompt + + Returns: + int: token count + """ + + async def async_count_token(self, prompt: str) -> int: + """Asynchronously count token of prompt + Args: + prompt (str): prompt + + Returns: + int: token count + """ + raise NotImplementedError + + @abstractmethod + def get_model_metadata(self, params: Dict) -> ModelMetadata: + """Get model metadata + + Args: + params (Dict): parameters, eg. {"model": "vicuna-13b-v1.5"} + """ + + async def async_get_model_metadata(self, params: Dict) -> ModelMetadata: + """Asynchronously get model metadata + + Args: + params (Dict): parameters, eg. {"model": "vicuna-13b-v1.5"} + """ + raise NotImplementedError + @abstractmethod def embeddings(self, params: Dict) -> List[List[float]]: """ diff --git a/dbgpt/model/proxy/llms/chatgpt.py b/dbgpt/model/proxy/llms/chatgpt.py index d81626e7a..54b2ae2c9 100755 --- a/dbgpt/model/proxy/llms/chatgpt.py +++ b/dbgpt/model/proxy/llms/chatgpt.py @@ -70,7 +70,7 @@ def _initialize_openai_v1(params: ProxyModelParameters): api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai") base_url = params.proxy_api_base or os.getenv( - "OPENAI_API_TYPE", + "OPENAI_API_BASE", os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None, ) api_key = params.proxy_api_key or os.getenv( diff --git a/dbgpt/model/utils/__init__.py b/dbgpt/model/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/model/utils/chatgpt_utils.py b/dbgpt/model/utils/chatgpt_utils.py new file mode 100644 index 000000000..f3aea5a07 --- /dev/null +++ b/dbgpt/model/utils/chatgpt_utils.py @@ -0,0 +1,282 @@ +from __future__ import annotations + +import os +import logging +from dataclasses import dataclass +import importlib.metadata as metadata +from typing import List, Dict, Any, Optional, TYPE_CHECKING, Union, AsyncIterator + +from dbgpt.core.interface.llm import ModelMetadata, LLMClient +from dbgpt.core.interface.llm import ModelOutput, ModelRequest + +if TYPE_CHECKING: + import httpx + from httpx._types import ProxiesTypes + from openai import AsyncAzureOpenAI + from openai import AsyncOpenAI + + ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI] + +logger = logging.getLogger(__name__) + + +@dataclass +class OpenAIParameters: + """A class to represent a LLM model.""" + + api_type: str = "open_ai" + api_base: Optional[str] = None + api_key: Optional[str] = None + api_version: Optional[str] = None + full_url: Optional[str] = None + proxies: Optional["ProxiesTypes"] = None + + +def _initialize_openai_v1(init_params: OpenAIParameters): + try: + from openai import OpenAI + except ImportError as exc: + raise ValueError( + "Could not import python package: openai " + "Please install openai by command `pip install openai" + ) from exc + + if not metadata.version("openai") >= "1.0.0": + raise ImportError("Please upgrade openai package to version 1.0.0 or above") + + api_type: Optional[str] = init_params.api_type + api_base: Optional[str] = init_params.api_base + api_key: Optional[str] = init_params.api_key + api_version: Optional[str] = init_params.api_version + full_url: Optional[str] = init_params.full_url + + api_type = api_type or os.getenv("OPENAI_API_TYPE", "open_ai") + + base_url = api_base or os.getenv( + "OPENAI_API_BASE", + os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None, + ) + api_key = api_key or os.getenv( + "OPENAI_API_KEY", + os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None, + ) + api_version = api_version or os.getenv("OPENAI_API_VERSION") + + if not base_url and full_url: + base_url = full_url.split("/chat/completions")[0] + + if api_key is None: + raise ValueError("api_key is required, please set OPENAI_API_KEY environment") + if base_url is None: + raise ValueError("base_url is required, please set OPENAI_BASE_URL environment") + if base_url.endswith("/"): + base_url = base_url[:-1] + + openai_params = { + "api_key": api_key, + "base_url": base_url, + } + return openai_params, api_type, api_version + + +def _build_openai_client(init_params: OpenAIParameters): + import httpx + + openai_params, api_type, api_version = _initialize_openai_v1(init_params) + if api_type == "azure": + from openai import AsyncAzureOpenAI + + return AsyncAzureOpenAI( + api_key=openai_params["api_key"], + api_version=api_version, + azure_endpoint=openai_params["base_url"], + http_client=httpx.AsyncClient(proxies=init_params.proxies), + ) + else: + from openai import AsyncOpenAI + + return AsyncOpenAI( + **openai_params, http_client=httpx.AsyncClient(proxies=init_params.proxies) + ) + + +class OpenAILLMClient(LLMClient): + """An implementation of LLMClient using OpenAI API. + + In order to have as few dependencies as possible, we directly use the http API. + """ + + def __init__( + self, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + api_type: Optional[str] = None, + api_version: Optional[str] = None, + model: Optional[str] = "gpt-3.5-turbo", + proxies: Optional["ProxiesTypes"] = None, + timeout: Optional[int] = 240, + model_alias: Optional[str] = "chatgpt_proxyllm", + context_length: Optional[int] = 8192, + openai_client: Optional["ClientType"] = None, + openai_kwargs: Optional[Dict[str, Any]] = None, + ): + self._init_params = OpenAIParameters( + api_type=api_type, + api_base=api_base, + api_key=api_key, + api_version=api_version, + proxies=proxies, + ) + + self._model = model + self._proxies = proxies + self._timeout = timeout + self._model_alias = model_alias + self._context_length = context_length + self._client = openai_client + self._openai_kwargs = openai_kwargs or {} + + @property + def client(self) -> ClientType: + if self._client is None: + self._client = _build_openai_client(init_params=self._init_params) + return self._client + + def _build_request( + self, request: ModelRequest, stream: Optional[bool] = False + ) -> Dict[str, Any]: + payload = {"model": request.model or self._model, "stream": stream} + + # Apply openai kwargs + for k, v in self._openai_kwargs.items(): + payload[k] = v + if request.temperature: + payload["temperature"] = request.temperature + if request.max_new_tokens: + payload["max_tokens"] = request.max_new_tokens + return payload + + async def generate(self, request: ModelRequest) -> ModelOutput: + messages = request.to_openai_messages() + payload = self._build_request(request) + try: + chat_completion = await self.client.chat.completions.create( + messages=messages, **payload + ) + text = chat_completion.choices[0].message.content + usage = chat_completion.usage.dict() + return ModelOutput(text=text, error_code=0, usage=usage) + except Exception as e: + return ModelOutput( + text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", + error_code=1, + ) + + async def generate_stream( + self, request: ModelRequest + ) -> AsyncIterator[ModelOutput]: + messages = request.to_openai_messages() + payload = self._build_request(request) + try: + chat_completion = await self.client.chat.completions.create( + messages=messages, **payload + ) + text = "" + for r in chat_completion: + if len(r.choices) == 0: + continue + if r.choices[0].delta.content is not None: + content = r.choices[0].delta.content + text += content + yield ModelOutput(text=text, error_code=0) + except Exception as e: + yield ModelOutput( + text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", + error_code=1, + ) + + async def models(self) -> List[ModelMetadata]: + model_metadata = ModelMetadata( + model=self._model_alias, + context_length=await self.get_context_length(), + ) + return [model_metadata] + + async def get_context_length(self) -> int: + """Get the context length of the model. + + Returns: + int: The context length. + # TODO: This is a temporary solution. We should have a better way to get the context length. + eg. get real context length from the openai api. + """ + return self._context_length + + async def count_token(self, model: str, prompt: str) -> int: + """Count the number of tokens in a given prompt. + + TODO: Get the real number of tokens from the openai api or tiktoken package + """ + + raise NotImplementedError() + + +async def _to_openai_stream( + model: str, output_iter: AsyncIterator[ModelOutput] +) -> AsyncIterator[str]: + """Convert the output_iter to openai stream format. + + Args: + model (str): The model name. + output_iter (AsyncIterator[ModelOutput]): The output iterator. + """ + import json + import shortuuid + from fastchat.protocol.openai_api_protocol import ( + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + DeltaMessage, + ) + + id = f"chatcmpl-{shortuuid.random()}" + + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(role="assistant"), + finish_reason=None, + ) + chunk = ChatCompletionStreamResponse(id=id, choices=[choice_data], model=model) + yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + + previous_text = "" + finish_stream_events = [] + async for model_output in output_iter: + model_output: ModelOutput = model_output + if model_output.error_code != 0: + yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + return + decoded_unicode = model_output.text.replace("\ufffd", "") + delta_text = decoded_unicode[len(previous_text) :] + previous_text = ( + decoded_unicode + if len(decoded_unicode) > len(previous_text) + else previous_text + ) + + if len(delta_text) == 0: + delta_text = None + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(content=delta_text), + finish_reason=model_output.finish_reason, + ) + chunk = ChatCompletionStreamResponse(id=id, choices=[choice_data], model=model) + if delta_text is None: + if model_output.finish_reason is not None: + finish_stream_events.append(chunk) + continue + yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + for finish_chunk in finish_stream_events: + yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" diff --git a/examples/awel/simple_llm_client_example.py b/examples/awel/simple_llm_client_example.py new file mode 100644 index 000000000..95be4d76e --- /dev/null +++ b/examples/awel/simple_llm_client_example.py @@ -0,0 +1,158 @@ +"""AWEL: Simple llm client example + + DB-GPT will automatically load and execute the current file after startup. + + Example: + + .. code-block:: shell + + DBGPT_SERVER="http://127.0.0.1:5000" + curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/generate \ + -H "Content-Type: application/json" -d '{ + "model": "proxyllm", + "messages": "hello" + }' + + curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/generate_stream \ + -H "Content-Type: application/json" -d '{ + "model": "proxyllm", + "messages": "hello", + "stream": true + }' + + curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/count_token \ + -H "Content-Type: application/json" -d '{ + "model": "proxyllm", + "messages": "hello" + }' + +""" +from typing import Dict, Any, AsyncIterator, Optional, Union, List +from dbgpt._private.pydantic import BaseModel, Field +from dbgpt.component import ComponentType +from dbgpt.core.awel import DAG, HttpTrigger, MapOperator, TransformStreamAbsOperator +from dbgpt.core import ( + ModelMessage, + LLMClient, + LLMOperator, + StreamingLLMOperator, + ModelOutput, + ModelRequest, +) +from dbgpt.model import DefaultLLMClient +from dbgpt.model.cluster import WorkerManagerFactory + + +class TriggerReqBody(BaseModel): + messages: Union[str, List[Dict[str, str]]] = Field( + ..., description="User input messages" + ) + model: str = Field(..., description="Model name") + stream: Optional[bool] = Field(default=False, description="Whether return stream") + + +class RequestHandleOperator(MapOperator[TriggerReqBody, ModelRequest]): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def map(self, input_value: TriggerReqBody) -> ModelRequest: + messages = [ModelMessage.build_human_message(input_value.messages)] + await self.current_dag_context.save_to_share_data( + "request_model_name", input_value.model + ) + return ModelRequest( + model=input_value.model, + messages=messages, + echo=False, + ) + + +class LLMMixin: + @property + def llm_client(self) -> LLMClient: + if not self._llm_client: + worker_manager = self.system_app.get_component( + ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + ).create() + self._llm_client = DefaultLLMClient(worker_manager) + return self._llm_client + + +class MyLLMOperator(LLMMixin, LLMOperator): + def __init__(self, llm_client: LLMClient = None, **kwargs): + super().__init__(llm_client, **kwargs) + + +class MyStreamingLLMOperator(LLMMixin, StreamingLLMOperator): + def __init__(self, llm_client: LLMClient = None, **kwargs): + super().__init__(llm_client, **kwargs) + + +class MyLLMStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]): + async def transform_stream( + self, input_value: AsyncIterator[ModelOutput] + ) -> AsyncIterator[str]: + from dbgpt.model.utils.chatgpt_utils import _to_openai_stream + + model = await self.current_dag_context.get_share_data("request_model_name") + async for output in _to_openai_stream(model, input_value): + yield output + + +class MyModelToolOperator(MapOperator[TriggerReqBody, Dict[str, Any]]): + def __init__(self, llm_client: LLMClient = None, **kwargs): + super().__init__(**kwargs) + self._llm_client = llm_client + + @property + def llm_client(self) -> LLMClient: + if not self._llm_client: + worker_manager = self.system_app.get_component( + ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + ).create() + self._llm_client = DefaultLLMClient(worker_manager) + return self._llm_client + + async def map(self, input_value: TriggerReqBody) -> Dict[str, Any]: + prompt_tokens = await self.llm_client.count_token( + input_value.model, input_value.messages + ) + available_models = await self.llm_client.models() + return { + "prompt_tokens": prompt_tokens, + "available_models": available_models, + } + + +with DAG("dbgpt_awel_simple_llm_client_generate") as client_generate_dag: + # Receive http request and trigger dag to run. + trigger = HttpTrigger( + "/examples/simple_client/generate", methods="POST", request_body=TriggerReqBody + ) + request_handle_task = RequestHandleOperator() + model_task = MyLLMOperator() + model_parse_task = MapOperator(lambda out: out.to_dict()) + trigger >> request_handle_task >> model_task >> model_parse_task + +with DAG("dbgpt_awel_simple_llm_client_generate_stream") as client_generate_stream_dag: + # Receive http request and trigger dag to run. + trigger = HttpTrigger( + "/examples/simple_client/generate_stream", + methods="POST", + request_body=TriggerReqBody, + streaming_response=True, + ) + request_handle_task = RequestHandleOperator() + model_task = MyStreamingLLMOperator() + openai_format_stream_task = MyLLMStreamingOperator() + trigger >> request_handle_task >> model_task >> openai_format_stream_task + +with DAG("dbgpt_awel_simple_llm_client_count_token") as client_count_token_dag: + # Receive http request and trigger dag to run. + trigger = HttpTrigger( + "/examples/simple_client/count_token", + methods="POST", + request_body=TriggerReqBody, + ) + model_task = MyModelToolOperator() + trigger >> model_task diff --git a/examples/sdk/simple_sdk_llm_example.py b/examples/sdk/simple_sdk_llm_example.py index 1cce17667..a33c1ceec 100644 --- a/examples/sdk/simple_sdk_llm_example.py +++ b/examples/sdk/simple_sdk_llm_example.py @@ -1,13 +1,19 @@ import asyncio from dbgpt.core.awel import DAG -from dbgpt.core import BaseOutputParser, OpenAILLM, RequestBuildOperator, PromptTemplate +from dbgpt.core import ( + BaseOutputParser, + RequestBuildOperator, + PromptTemplate, + LLMOperator, +) +from dbgpt.model import OpenAILLMClient with DAG("simple_sdk_llm_example_dag") as dag: prompt_task = PromptTemplate.from_template( "Write a SQL of {dialect} to query all data of {table_name}." ) model_pre_handle_task = RequestBuildOperator(model="gpt-3.5-turbo") - llm_task = OpenAILLM() + llm_task = LLMOperator(OpenAILLMClient()) out_parse_task = BaseOutputParser() prompt_task >> model_pre_handle_task >> llm_task >> out_parse_task diff --git a/examples/sdk/simple_sdk_llm_sql_example.py b/examples/sdk/simple_sdk_llm_sql_example.py index 4aedf12c6..249ae57a6 100644 --- a/examples/sdk/simple_sdk_llm_sql_example.py +++ b/examples/sdk/simple_sdk_llm_sql_example.py @@ -8,10 +8,16 @@ JoinOperator, MapOperator, ) -from dbgpt.core import SQLOutputParser, OpenAILLM, RequestBuildOperator, PromptTemplate +from dbgpt.core import ( + SQLOutputParser, + LLMOperator, + RequestBuildOperator, + PromptTemplate, +) from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect from dbgpt.datasource.operator.datasource_operator import DatasourceOperator from dbgpt.rag.operator.datasource import DatasourceRetrieverOperator +from dbgpt.model import OpenAILLMClient def _create_temporary_connection(): @@ -115,7 +121,7 @@ def _combine_result(self, sql_result_df, model_result: Dict) -> Dict: prompt_input_task = JoinOperator(combine_function=_join_func) prompt_task = PromptTemplate.from_template(_sql_prompt()) model_pre_handle_task = RequestBuildOperator(model="gpt-3.5-turbo") - llm_task = OpenAILLM() + llm_task = LLMOperator(OpenAILLMClient()) out_parse_task = SQLOutputParser() sql_parse_task = MapOperator(map_function=lambda x: x["sql"]) db_query_task = DatasourceOperator(connection=db_connection) From 60d646da83072e8d3d400c058efcaf46625f0a36 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Sat, 23 Dec 2023 02:01:55 +0800 Subject: [PATCH 2/5] fix(datasource): Fix SQLite get index bug --- dbgpt/datasource/rdbms/conn_clickhouse.py | 10 +++++----- dbgpt/datasource/rdbms/conn_sqlite.py | 9 ++++++++- dbgpt/rag/summary/db_summary_client.py | 4 +++- examples/awel/simple_llm_client_example.py | 13 ++----------- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/dbgpt/datasource/rdbms/conn_clickhouse.py b/dbgpt/datasource/rdbms/conn_clickhouse.py index 217cc905e..2b43bd431 100644 --- a/dbgpt/datasource/rdbms/conn_clickhouse.py +++ b/dbgpt/datasource/rdbms/conn_clickhouse.py @@ -1,13 +1,8 @@ import re import sqlparse -import clickhouse_connect from typing import List, Optional, Any, Iterable, Dict from sqlalchemy import text -from urllib.parse import quote -from sqlalchemy.schema import CreateTable -from urllib.parse import quote_plus as urlquote from dbgpt.datasource.rdbms.base import RDBMSDatabase -from clickhouse_connect.driver import httputil from dbgpt.storage.schema import DBType from sqlalchemy import ( MetaData, @@ -56,6 +51,11 @@ def from_uri_db( engine_args: Optional[dict] = None, **kwargs: Any, ) -> RDBMSDatabase: + import clickhouse_connect + from clickhouse_connect.driver import httputil + + # Lazy import + big_pool_mgr = httputil.get_pool_manager(maxsize=16, num_pools=12) client = clickhouse_connect.get_client( host=host, diff --git a/dbgpt/datasource/rdbms/conn_sqlite.py b/dbgpt/datasource/rdbms/conn_sqlite.py index b535cd80f..bed71bfd6 100644 --- a/dbgpt/datasource/rdbms/conn_sqlite.py +++ b/dbgpt/datasource/rdbms/conn_sqlite.py @@ -37,7 +37,14 @@ def get_indexes(self, table_name): """Get table indexes about specified table.""" cursor = self.session.execute(text(f"PRAGMA index_list({table_name})")) indexes = cursor.fetchall() - return [(index[1], index[3]) for index in indexes] + result = [] + for idx in indexes: + index_name = idx[1] + cursor = self.session.execute(text(f"PRAGMA index_info({index_name})")) + index_infos = cursor.fetchall() + column_names = [index_info[2] for index_info in index_infos] + result.append({"name": index_name, "column_names": column_names}) + return result def get_show_create_table(self, table_name): """Get table show create table about specified table.""" diff --git a/dbgpt/rag/summary/db_summary_client.py b/dbgpt/rag/summary/db_summary_client.py index ccbdc0125..affb08c60 100644 --- a/dbgpt/rag/summary/db_summary_client.py +++ b/dbgpt/rag/summary/db_summary_client.py @@ -1,5 +1,6 @@ import logging +import traceback from dbgpt.component import SystemApp from dbgpt._private.config import Config from dbgpt.configs.model_config import ( @@ -67,8 +68,9 @@ def init_db_summary(self): try: self.db_summary_embedding(item["db_name"], item["db_type"]) except Exception as e: + message = traceback.format_exc() logger.warn( - f'{item["db_name"]}, {item["db_type"]} summary error!{str(e)}', e + f'{item["db_name"]}, {item["db_type"]} summary error!{str(e)}, detail: {message}' ) def init_db_profile(self, db_summary_client, dbname, embeddings): diff --git a/examples/awel/simple_llm_client_example.py b/examples/awel/simple_llm_client_example.py index 95be4d76e..2d12d0eac 100644 --- a/examples/awel/simple_llm_client_example.py +++ b/examples/awel/simple_llm_client_example.py @@ -99,19 +99,10 @@ async def transform_stream( yield output -class MyModelToolOperator(MapOperator[TriggerReqBody, Dict[str, Any]]): +class MyModelToolOperator(LLMMixin, MapOperator[TriggerReqBody, Dict[str, Any]]): def __init__(self, llm_client: LLMClient = None, **kwargs): - super().__init__(**kwargs) self._llm_client = llm_client - - @property - def llm_client(self) -> LLMClient: - if not self._llm_client: - worker_manager = self.system_app.get_component( - ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory - ).create() - self._llm_client = DefaultLLMClient(worker_manager) - return self._llm_client + MapOperator.__init__(self, **kwargs) async def map(self, input_value: TriggerReqBody) -> Dict[str, Any]: prompt_tokens = await self.llm_client.count_token( From 940d3afd83291ff206b2b933ffbef7c4e7055985 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Sat, 23 Dec 2023 09:23:38 +0800 Subject: [PATCH 3/5] fix(datasource): Fix SQLite unit test error --- dbgpt/datasource/rdbms/tests/test_conn_sqlite.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dbgpt/datasource/rdbms/tests/test_conn_sqlite.py b/dbgpt/datasource/rdbms/tests/test_conn_sqlite.py index 28655c46f..267ee1575 100644 --- a/dbgpt/datasource/rdbms/tests/test_conn_sqlite.py +++ b/dbgpt/datasource/rdbms/tests/test_conn_sqlite.py @@ -47,7 +47,8 @@ def test_run_no_throw(db): def test_get_indexes(db): db.run("CREATE TABLE test (name TEXT);") db.run("CREATE INDEX idx_name ON test(name);") - assert db.get_indexes("test") == [("idx_name", "c")] + indexes = db.get_indexes("test") + assert indexes == [{"name": "idx_name", "column_names": ["name"]}] def test_get_indexes_empty(db): From a92321d9211a5f8481d4682f5cc6516d584dea2b Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Sat, 23 Dec 2023 10:43:30 +0800 Subject: [PATCH 4/5] build(core): Add Makefile for test/fmt/pre-commit --- Makefile | 48 ++++++++++++++++++++++++++++++ dbgpt/datasource/rdbms/base.py | 5 +++- requirements/dev-requirements.txt | 2 -- requirements/lint-requirements.txt | 11 +++++++ setup.py | 43 +++++++++++++++----------- 5 files changed, 88 insertions(+), 21 deletions(-) create mode 100644 Makefile create mode 100644 requirements/lint-requirements.txt diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..5b097459e --- /dev/null +++ b/Makefile @@ -0,0 +1,48 @@ +.DEFAULT_GOAL := help + +SHELL=/bin/bash +VENV = venv + +# Detect the operating system and set the virtualenv bin directory +ifeq ($(OS),Windows_NT) + VENV_BIN=$(VENV)/Scripts +else + VENV_BIN=$(VENV)/bin +endif + +setup: ## Set up the Python development environment + python3 -m venv $(VENV) + $(VENV_BIN)/pip install --upgrade pip + $(VENV_BIN)/pip install -r requirements/dev-requirements.txt + $(VENV_BIN)/pip install -r requirements/lint-requirements.txt + +testenv: setup ## Set up the Python test environment + $(VENV_BIN)/pip install -e ".[simple_framework]" + +.PHONY: fmt +fmt: setup ## Format Python code + $(VENV_BIN)/black . + +.PHONY: pre-commit +pre-commit: fmt test ## Run formatting and unit tests before committing + +.PHONY: test +test: testenv ## Run unit tests + $(VENV_BIN)/pytest dbgpt + +.PHONY: coverage +coverage: setup ## Run tests and report coverage + $(VENV_BIN)/pytest dbgpt --cov=dbgpt + +.PHONY: clean +clean: ## Clean up the environment + rm -rf $(VENV) + find . -type f -name '*.pyc' -delete + find . -type d -name '__pycache__' -delete + find . -type d -name '.pytest_cache' -delete + find . -type d -name '.coverage' -delete + +.PHONY: help +help: ## Display this help screen + @echo "Available commands:" + @grep -E '^[a-z.A-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-18s\033[0m %s\n", $$1, $$2}' | sort diff --git a/dbgpt/datasource/rdbms/base.py b/dbgpt/datasource/rdbms/base.py index bf1a594be..a8d0846f4 100644 --- a/dbgpt/datasource/rdbms/base.py +++ b/dbgpt/datasource/rdbms/base.py @@ -1,7 +1,6 @@ from __future__ import annotations import sqlparse import regex as re -import pandas as pd from urllib.parse import quote from urllib.parse import quote_plus as urlquote from typing import Any, Iterable, List, Optional, Dict @@ -383,6 +382,10 @@ def run(self, command: str, fetch: str = "all") -> List: return self.get_simple_fields(table_name) def run_to_df(self, command: str, fetch: str = "all"): + import pandas as pd + + # Pandas has too much dependence and the import time is too long + # TODO: Remove the dependency on pandas result_lst = self.run(command, fetch) colunms = result_lst[0] values = result_lst[1:] diff --git a/requirements/dev-requirements.txt b/requirements/dev-requirements.txt index 3ea930c1c..b122a3ee5 100644 --- a/requirements/dev-requirements.txt +++ b/requirements/dev-requirements.txt @@ -10,7 +10,5 @@ pytest-mock pytest-recording pytesseract==0.3.10 aioresponses -# python code format, usage `black .` -black # for git hooks pre-commit \ No newline at end of file diff --git a/requirements/lint-requirements.txt b/requirements/lint-requirements.txt new file mode 100644 index 000000000..91639ca6f --- /dev/null +++ b/requirements/lint-requirements.txt @@ -0,0 +1,11 @@ +# python code format, usage `black .` +black==22.8.0 +blackdoc==0.3.7 +flake8==5.0.4 +flake8-bugbear==22.10.25 +flake8-comprehensions==3.10.0 +flake8-docstrings==1.6.0 +flake8-simplify==0.19.3 +flake8-tidy-imports==4.8.0 +isort==5.10.1 +pyupgrade==3.1.0 diff --git a/setup.py b/setup.py index 2b670b35a..0d219f497 100644 --- a/setup.py +++ b/setup.py @@ -364,23 +364,40 @@ def core_requires(): "prettytable", "cachetools", ] - - setup_spec.extras["framework"] = [ - "coloredlogs", + # Just use by DB-GPT internal, we should find the smallest dependency set for run we core unit test. + # The dependency "framework" is too large for now. + setup_spec.extras["simple_framework"] = setup_spec.extras["core"] + [ + "pydantic<2,>=1", "httpx", + "fastapi==0.98.0", + "shortuuid", + # change from fixed version 2.0.22 to variable version, because other dependencies are >=1.4, such as pydoris is <2 + "SQLAlchemy>=1.4,<3", + # for cache + "msgpack", + # for cache, TODO pympler has not been updated for a long time and needs to find a new toolkit. + "pympler", "sqlparse==0.4.4", + "duckdb==0.8.1", + "duckdb-engine", + ] + # TODO: remove fschat from simple_framework + if BUILD_FROM_SOURCE: + setup_spec.extras["simple_framework"].append( + f"fschat @ {BUILD_FROM_SOURCE_URL_FAST_CHAT}" + ) + else: + setup_spec.extras["simple_framework"].append("fschat") + + setup_spec.extras["framework"] = setup_spec.extras["simple_framework"] + [ + "coloredlogs", "seaborn", # https://github.com/eosphoros-ai/DB-GPT/issues/551 "pandas==2.0.3", "auto-gpt-plugin-template", "gTTS==2.3.1", "langchain>=0.0.286", - # change from fixed version 2.0.22 to variable version, because other dependencies are >=1.4, such as pydoris is <2 - "SQLAlchemy>=1.4,<3", - "fastapi==0.98.0", "pymysql", - "duckdb==0.8.1", - "duckdb-engine", "jsonschema", # TODO move transformers to default # "transformers>=4.31.0", @@ -390,20 +407,10 @@ def core_requires(): "openpyxl==3.1.2", "chardet==5.1.0", "xlrd==2.0.1", - # for cache, TODO pympler has not been updated for a long time and needs to find a new toolkit. - "pympler", "aiofiles", - # for cache - "msgpack", # for agent "GitPython", ] - if BUILD_FROM_SOURCE: - setup_spec.extras["framework"].append( - f"fschat @ {BUILD_FROM_SOURCE_URL_FAST_CHAT}" - ) - else: - setup_spec.extras["framework"].append("fschat") def knowledge_requires(): From 904aedaa1e699921ec4826044acad438f1e0655d Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Sat, 23 Dec 2023 11:11:26 +0800 Subject: [PATCH 5/5] build(core): Add more python format tools --- Makefile | 13 +++++++++++++ requirements/dev-requirements.txt | 4 +++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 5b097459e..132cdc9bf 100644 --- a/Makefile +++ b/Makefile @@ -21,7 +21,20 @@ testenv: setup ## Set up the Python test environment .PHONY: fmt fmt: setup ## Format Python code + # TODO: Use isort to sort Python imports. + # https://github.com/PyCQA/isort + # $(VENV_BIN)/isort . + # https://github.com/psf/black $(VENV_BIN)/black . + # TODO: Use blackdoc to format Python doctests. + # https://blackdoc.readthedocs.io/en/latest/ + # $(VENV_BIN)/blackdoc . + # TODO: Type checking of Python code. + # https://github.com/python/mypy + # $(VENV_BIN)/mypy dbgpt + # TODO: uUse flake8 to enforce Python style guide. + # https://flake8.pycqa.org/en/latest/ + # $(VENV_BIN)/flake8 dbgpt .PHONY: pre-commit pre-commit: fmt test ## Run formatting and unit tests before committing diff --git a/requirements/dev-requirements.txt b/requirements/dev-requirements.txt index b122a3ee5..dc49dd0aa 100644 --- a/requirements/dev-requirements.txt +++ b/requirements/dev-requirements.txt @@ -11,4 +11,6 @@ pytest-recording pytesseract==0.3.10 aioresponses # for git hooks -pre-commit \ No newline at end of file +pre-commit +# Type checking +mypy==0.991 \ No newline at end of file