From 9eb654d2e5b7b665515287c70a04ae562bd50c22 Mon Sep 17 00:00:00 2001 From: kyriediculous Date: Wed, 31 Jul 2024 01:48:49 +0200 Subject: [PATCH 01/10] runner: add llm-generate route and pipeline --- runner/app/main.py | 7 +++ runner/app/pipelines/llm_generate.py | 50 +++++++++++++++ runner/app/routes/llm_generate.py | 55 +++++++++++++++++ runner/gen_openapi.py | 2 + runner/openapi.json | 91 ++++++++++++++++++++++++++++ runner/requirements.txt | 2 +- 6 files changed, 206 insertions(+), 1 deletion(-) create mode 100644 runner/app/pipelines/llm_generate.py create mode 100644 runner/app/routes/llm_generate.py diff --git a/runner/app/main.py b/runner/app/main.py index 6f511420..ac0dd190 100644 --- a/runner/app/main.py +++ b/runner/app/main.py @@ -52,6 +52,9 @@ def load_pipeline(pipeline: str, model_id: str) -> any: from app.pipelines.upscale import UpscalePipeline return UpscalePipeline(model_id) + case "llm": + from runner.app.pipelines.llm_generate import LLMGeneratePipeline + return LLMGeneratePipeline(model_id) case _: raise EnvironmentError( f"{pipeline} is not a valid pipeline for model {model_id}" @@ -82,6 +85,10 @@ def load_route(pipeline: str) -> any: from app.routes import upscale return upscale.router + case "llm": + from app.routes import llm + + return llm.router case _: raise EnvironmentError(f"{pipeline} is not a valid pipeline") diff --git a/runner/app/pipelines/llm_generate.py b/runner/app/pipelines/llm_generate.py new file mode 100644 index 00000000..1cab7464 --- /dev/null +++ b/runner/app/pipelines/llm_generate.py @@ -0,0 +1,50 @@ +import logging +import os +from typing import Dict, Any + +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig +from app.pipelines.base import Pipeline +from app.pipelines.utils import get_model_dir, get_torch_device + +logger = logging.getLogger(__name__) + + +class LLMGeneratePipeline(Pipeline): + def __init__(self, model_id: str): + self.model_id = model_id + self.device = get_torch_device() + + # Load tokenizer and model + self.tokenizer = AutoTokenizer.from_pretrained( + model_id, cache_dir=get_model_dir()) + self.model = AutoModelForCausalLM.from_pretrained( + model_id, + cache_dir=get_model_dir(), + torch_dtype=torch.float16, + device_map="auto", + ) + + # Set up generation config + self.generation_config = GenerationConfig.from_pretrained(model_id) + self.generation_config.max_length = 2048 # Adjust as needed + + def __call__(self, prompt: str, **kwargs) -> Dict[str, Any]: + # Encode the prompt + input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device) + + # Generate response + with torch.no_grad(): + output = self.model.generate( + input_ids, + generation_config=self.generation_config, + **kwargs + ) + + # Decode the response + response = self.tokenizer.decode(output[0], skip_special_tokens=True) + + return response + + def __str__(self) -> str: + return f"LLMGeneratePipeline model_id={self.model_id}" diff --git a/runner/app/routes/llm_generate.py b/runner/app/routes/llm_generate.py new file mode 100644 index 00000000..bb6773a7 --- /dev/null +++ b/runner/app/routes/llm_generate.py @@ -0,0 +1,55 @@ +import logging +from typing import Annotated +from fastapi import APIRouter, Depends, Form, status +from fastapi.responses import JSONResponse +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from app.dependencies import get_pipeline +from app.pipelines.base import Pipeline +from app.routes.util import HTTPError, TextResponse, http_error + +router = 
APIRouter() + +logger = logging.getLogger(__name__) + +RESPONSES = { + status.HTTP_400_BAD_REQUEST: {"model": HTTPError}, + status.HTTP_401_UNAUTHORIZED: {"model": HTTPError}, + status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError}, +} + + +@router.post("/llm-generate", response_model=TextResponse, responses=RESPONSES) +@router.post("/llm-generate", response_model=TextResponse, responses=RESPONSES, include_in_schema=False) +async def llm_generate( + prompt: Annotated[str, Form()], + model_id: Annotated[str, Form()] = "", + pipeline: Pipeline = Depends(get_pipeline), + token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), +): + auth_token = os.environ.get("AUTH_TOKEN") + if auth_token: + if not token or token.credentials != auth_token: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + headers={"WWW-Authenticate": "Bearer"}, + content=http_error("Invalid bearer token"), + ) + + if model_id != "" and model_id != pipeline.model_id: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=http_error( + f"pipeline configured with {pipeline.model_id} but called with " + f"{model_id}" + ), + ) + + try: + result = pipeline(prompt=prompt) + return JSONResponse(content=result) + except Exception as e: + logger.error(f"LLM processing error: {str(e)}") + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content=http_error("Internal server error during LLM processing."), + ) diff --git a/runner/gen_openapi.py b/runner/gen_openapi.py index 7fde5ee3..700485f7 100644 --- a/runner/gen_openapi.py +++ b/runner/gen_openapi.py @@ -12,6 +12,7 @@ image_to_video, text_to_image, upscale, + llm_generate ) from fastapi.openapi.utils import get_openapi @@ -85,6 +86,7 @@ def write_openapi(fname, entrypoint="runner"): app.include_router(image_to_video.router) app.include_router(upscale.router) app.include_router(audio_to_text.router) + app.include_router(llm_generate.router) use_route_names_as_operation_ids(app) diff --git a/runner/openapi.json b/runner/openapi.json index 4345e565..cae12e0f 100644 --- a/runner/openapi.json +++ b/runner/openapi.json @@ -404,6 +404,79 @@ } ] } + }, + "/llm-generate": { + "post": { + "summary": "Llm Generate", + "operationId": "llm_generate", + "requestBody": { + "content": { + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/Body_llm_generate_llm_generate_post" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TextResponse" + } + } + } + }, + "400": { + "description": "Bad Request", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "401": { + "description": "Unauthorized", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "500": { + "description": "Internal Server Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + }, + "security": [ + { + "HTTPBearer": [] + } + ] + } } }, "components": { @@ -561,6 +634,24 @@ ], "title": "Body_image_to_video_image_to_video_post" }, + "Body_llm_generate_llm_generate_post": { + "properties": { + "prompt": { + "type": "string", + 
"title": "Prompt" + }, + "model_id": { + "type": "string", + "title": "Model Id", + "default": "" + } + }, + "type": "object", + "required": [ + "prompt" + ], + "title": "Body_llm_generate_llm_generate_post" + }, "Body_upscale_upscale_post": { "properties": { "prompt": { diff --git a/runner/requirements.txt b/runner/requirements.txt index 3852800d..21de29da 100644 --- a/runner/requirements.txt +++ b/runner/requirements.txt @@ -8,7 +8,7 @@ python-multipart==0.0.9 uvicorn==0.30.0 huggingface_hub==0.23.2 xformers==0.0.23 -triton>=2.1.0 +triton>=0.1.0 peft==0.11.1 deepcache==0.1.1 safetensors==0.4.3 From 9ecabab06ed57194fd6e127dcfd8172416dd9838 Mon Sep 17 00:00:00 2001 From: kyriediculous Date: Wed, 31 Jul 2024 01:57:49 +0200 Subject: [PATCH 02/10] add llama3.1 8B to downloads --- runner/dl_checkpoints.sh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh index 9fe40837..bd0e024b 100755 --- a/runner/dl_checkpoints.sh +++ b/runner/dl_checkpoints.sh @@ -60,6 +60,10 @@ function download_all_models() { # Download image-to-video models. huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models + + # Download LLM models (Warning: large model size) + huggingface-cli download meta-llama/Meta-Llama-3.1-8B --include "original/*" --local-dir Meta-Llama-3.1-8B + } # Enable HF transfer acceleration. From a8362a759f7a7e7c19861c225ff6e6508e60bdaf Mon Sep 17 00:00:00 2001 From: kyriediculous Date: Wed, 31 Jul 2024 01:49:01 +0200 Subject: [PATCH 03/10] worker: add llm-generate container management --- runner/app/main.py | 10 +- runner/app/pipelines/llm_generate.py | 149 ++++++++++++--- runner/app/routes/llm_generate.py | 19 +- runner/app/routes/util.py | 5 + runner/dl_checkpoints.sh | 2 +- runner/openapi.json | 32 +++- worker/docker.go | 1 + worker/multipart.go | 38 ++++ worker/runner.gen.go | 272 ++++++++++++++++++++++++--- worker/worker.go | 48 +++++ 10 files changed, 509 insertions(+), 67 deletions(-) diff --git a/runner/app/main.py b/runner/app/main.py index ac0dd190..f215c377 100644 --- a/runner/app/main.py +++ b/runner/app/main.py @@ -52,8 +52,8 @@ def load_pipeline(pipeline: str, model_id: str) -> any: from app.pipelines.upscale import UpscalePipeline return UpscalePipeline(model_id) - case "llm": - from runner.app.pipelines.llm_generate import LLMGeneratePipeline + case "llm-generate": + from app.pipelines.llm_generate import LLMGeneratePipeline return LLMGeneratePipeline(model_id) case _: raise EnvironmentError( @@ -85,10 +85,10 @@ def load_route(pipeline: str) -> any: from app.routes import upscale return upscale.router - case "llm": - from app.routes import llm + case "llm-generate": + from app.routes import llm_generate - return llm.router + return llm_generate.router case _: raise EnvironmentError(f"{pipeline} is not a valid pipeline") diff --git a/runner/app/pipelines/llm_generate.py b/runner/app/pipelines/llm_generate.py index 1cab7464..0f70e1b2 100644 --- a/runner/app/pipelines/llm_generate.py +++ b/runner/app/pipelines/llm_generate.py @@ -1,50 +1,141 @@ import logging import os -from typing import Dict, Any +from typing import Dict, Any, Optional import torch -from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig +from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline from app.pipelines.base import Pipeline from app.pipelines.utils import get_model_dir, get_torch_device +from huggingface_hub import file_download, hf_hub_download 
logger = logging.getLogger(__name__) +# class LLMGeneratePipeline(Pipeline): +# def __init__(self, model_id: str): +# self.model_id = model_id +# kwargs = { +# "cache_dir": get_model_dir() +# } +# self.device = get_torch_device() +# folder_name = file_download.repo_folder_name( +# repo_id=model_id, repo_type="model" +# ) +# folder_path = os.path.join(get_model_dir(), folder_name) + +# # Check for fp16 variant +# has_fp16_variant = any( +# ".fp16.safetensors" in fname +# for _, _, files in os.walk(folder_path) +# for fname in files +# ) +# if self.device != "cpu" and has_fp16_variant: +# logger.info("LLMGeneratePipeline loading fp16 variant for %s", model_id) +# kwargs["torch_dtype"] = torch.float16 +# kwargs["variant"] = "fp16" + +# # Load tokenizer +# self.tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs) + +# # Load model +# self.model = AutoModelForCausalLM.from_pretrained( +# model_id, **kwargs).to(self.device) + +# # Set up generation config +# self.generation_config = self.model.generation_config + +# # Optional: Add optimizations +# sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true" +# if sfast_enabled: +# logger.info( +# "LLMGeneratePipeline will be dynamically compiled with stable-fast for %s", +# model_id, +# ) +# from app.pipelines.optim.sfast import compile_model +# self.model = compile_model(self.model) + +# def __call__(self, prompt: str, system_msg: Optional[str] = None, +# temperature: Optional[float] = None, +# max_tokens: Optional[int] = None, **kwargs) -> Dict[str, Any]: +# if system_msg: +# input_text = f"{system_msg}\n\n{prompt}" +# else: +# input_text = prompt + +# input_ids = self.tokenizer.encode( +# input_text, return_tensors="pt").to(self.device) + +# # Update generation config +# gen_kwargs = {} +# if temperature is not None: +# gen_kwargs['temperature'] = temperature +# if max_tokens is not None: +# gen_kwargs['max_new_tokens'] = max_tokens + +# # Merge generation config with provided kwargs +# gen_kwargs = {**self.generation_config.to_dict(), **gen_kwargs, **kwargs} + +# # Generate response +# with torch.no_grad(): +# output = self.model.generate( +# input_ids, +# **gen_kwargs +# ) + +# # Decode the response +# response = self.tokenizer.decode(output[0], skip_special_tokens=True) + +# # Calculate tokens used +# tokens_used = len(output[0]) + +# return { +# "response": response.strip(), +# "tokens_used": tokens_used +# } + +# def __str__(self) -> str: +# return f"LLMPipeline model_id={self.model_id}" + + class LLMGeneratePipeline(Pipeline): def __init__(self, model_id: str): self.model_id = model_id self.device = get_torch_device() - # Load tokenizer and model - self.tokenizer = AutoTokenizer.from_pretrained( - model_id, cache_dir=get_model_dir()) - self.model = AutoModelForCausalLM.from_pretrained( - model_id, - cache_dir=get_model_dir(), - torch_dtype=torch.float16, - device_map="auto", - ) + kwargs = { + "cache_dir": get_model_dir(), + "device_map": "auto", + "torch_dtype": torch.bfloat16 if self.device != "cpu" else torch.float32, + } - # Set up generation config - self.generation_config = GenerationConfig.from_pretrained(model_id) - self.generation_config.max_length = 2048 # Adjust as needed + logger.info(f"Loading model {model_id}") + self.pipeline = pipeline( + "text-generation", + model=model_id, + tokenizer=model_id, + **kwargs + ) - def __call__(self, prompt: str, **kwargs) -> Dict[str, Any]: - # Encode the prompt - input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device) + def __call__(self, prompt: str, 
system_msg: str = None, **kwargs): + messages = [] + if system_msg: + messages.append({"role": "system", "content": system_msg}) + messages.append({"role": "user", "content": prompt}) - # Generate response - with torch.no_grad(): - output = self.model.generate( - input_ids, - generation_config=self.generation_config, - **kwargs - ) + outputs = self.pipeline( + messages, + max_new_tokens=kwargs.get("max_tokens", 256), + temperature=kwargs.get("temperature", 0.7), + ) - # Decode the response - response = self.tokenizer.decode(output[0], skip_special_tokens=True) + response = outputs[0]["generated_text"] + # Assuming the response is the last message in the conversation + response = response.split("assistant:")[-1].strip() - return response + return { + "response": response, + "tokens_used": len(self.pipeline.tokenizer.encode(response)) + } - def __str__(self) -> str: - return f"LLMGeneratePipeline model_id={self.model_id}" + def __str__(self): + return f"LLMGeneratePipeline(model_id={self.model_id})" diff --git a/runner/app/routes/llm_generate.py b/runner/app/routes/llm_generate.py index bb6773a7..d43120d3 100644 --- a/runner/app/routes/llm_generate.py +++ b/runner/app/routes/llm_generate.py @@ -1,11 +1,12 @@ import logging -from typing import Annotated +import os +from typing import Annotated, Optional from fastapi import APIRouter, Depends, Form, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from app.dependencies import get_pipeline from app.pipelines.base import Pipeline -from app.routes.util import HTTPError, TextResponse, http_error +from app.routes.util import HTTPError, LlmResponse, TextResponse, http_error router = APIRouter() @@ -18,11 +19,14 @@ } -@router.post("/llm-generate", response_model=TextResponse, responses=RESPONSES) -@router.post("/llm-generate", response_model=TextResponse, responses=RESPONSES, include_in_schema=False) +@router.post("/llm-generate", response_model=LlmResponse, responses=RESPONSES) +@router.post("/llm-generate/", response_model=LlmResponse, responses=RESPONSES, include_in_schema=False) async def llm_generate( prompt: Annotated[str, Form()], model_id: Annotated[str, Form()] = "", + system_msg: Annotated[str, Form()] = None, + temperature: Annotated[float, Form()] = None, + max_tokens: Annotated[int, Form()] = None, pipeline: Pipeline = Depends(get_pipeline), token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), ): @@ -45,7 +49,12 @@ async def llm_generate( ) try: - result = pipeline(prompt=prompt) + result = pipeline( + prompt=prompt, + system_msg=system_msg, + temperature=temperature, + max_tokens=max_tokens + ) return JSONResponse(content=result) except Exception as e: logger.error(f"LLM processing error: {str(e)}") diff --git a/runner/app/routes/util.py b/runner/app/routes/util.py index 96736305..e0fc914a 100644 --- a/runner/app/routes/util.py +++ b/runner/app/routes/util.py @@ -34,6 +34,11 @@ class TextResponse(BaseModel): chunks: List[chunk] +class LlmResponse(BaseModel): + response: str + tokens_used: int + + class APIError(BaseModel): msg: str diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh index bd0e024b..8af42a54 100755 --- a/runner/dl_checkpoints.sh +++ b/runner/dl_checkpoints.sh @@ -62,7 +62,7 @@ function download_all_models() { huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models # Download LLM models (Warning: large model size) - huggingface-cli download 
meta-llama/Meta-Llama-3.1-8B --include "original/*" --local-dir Meta-Llama-3.1-8B + huggingface-cli download meta-llama/Meta-Llama-3-8B-Instruct --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models } diff --git a/runner/openapi.json b/runner/openapi.json index cae12e0f..83ca1bff 100644 --- a/runner/openapi.json +++ b/runner/openapi.json @@ -425,7 +425,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/TextResponse" + "$ref": "#/components/schemas/LlmResponse" } } } @@ -644,6 +644,18 @@ "type": "string", "title": "Model Id", "default": "" + }, + "system_msg": { + "type": "string", + "title": "System Msg" + }, + "temperature": { + "type": "number", + "title": "Temperature" + }, + "max_tokens": { + "type": "integer", + "title": "Max Tokens" } }, "type": "object", @@ -742,6 +754,24 @@ ], "title": "ImageResponse" }, + "LlmResponse": { + "properties": { + "response": { + "type": "string", + "title": "Response" + }, + "tokens_used": { + "type": "integer", + "title": "Tokens Used" + } + }, + "type": "object", + "required": [ + "response", + "tokens_used" + ], + "title": "LlmResponse" + }, "Media": { "properties": { "url": { diff --git a/worker/docker.go b/worker/docker.go index 8d7f97e0..ce510493 100644 --- a/worker/docker.go +++ b/worker/docker.go @@ -35,6 +35,7 @@ var containerHostPorts = map[string]string{ "image-to-video": "8002", "upscale": "8003", "audio-to-text": "8004", + "llm": "8005", } type DockerManager struct { diff --git a/worker/multipart.go b/worker/multipart.go index 865b9114..7a87bcae 100644 --- a/worker/multipart.go +++ b/worker/multipart.go @@ -240,3 +240,41 @@ func NewAudioToTextMultipartWriter(w io.Writer, req AudioToTextMultipartRequestB return mw, nil } + +func NewLlmGenerateMultipartWriter(w io.Writer, req BodyLlmGenerateLlmGeneratePost) (*multipart.Writer, error) { + mw := multipart.NewWriter(w) + + if err := mw.WriteField("prompt", req.Prompt); err != nil { + return nil, fmt.Errorf("failed to write prompt field: %w", err) + } + + if req.ModelId != nil { + if err := mw.WriteField("model_id", *req.ModelId); err != nil { + return nil, fmt.Errorf("failed to write model_id field: %w", err) + } + } + + if req.SystemMsg != nil { + if err := mw.WriteField("system_msg", *req.SystemMsg); err != nil { + return nil, fmt.Errorf("failed to write system_msg field: %w", err) + } + } + + if req.Temperature != nil { + if err := mw.WriteField("temperature", fmt.Sprintf("%f", *req.Temperature)); err != nil { + return nil, fmt.Errorf("failed to write temperature field: %w", err) + } + } + + if req.MaxTokens != nil { + if err := mw.WriteField("max_tokens", strconv.Itoa(*req.MaxTokens)); err != nil { + return nil, fmt.Errorf("failed to write max_tokens field: %w", err) + } + } + + if err := mw.Close(); err != nil { + return nil, fmt.Errorf("failed to close multipart writer: %w", err) + } + + return mw, nil +} diff --git a/worker/runner.gen.go b/worker/runner.gen.go index 0dbe8036..61c1aa8a 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -66,6 +66,15 @@ type BodyImageToVideoImageToVideoPost struct { Width *int `json:"width,omitempty"` } +// BodyLlmGenerateLlmGeneratePost defines model for Body_llm_generate_llm_generate_post. 
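+// Prompt is the only required form field; the pointer fields are optional and are
+// written to the multipart request body only when non-nil.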
+type BodyLlmGenerateLlmGeneratePost struct { + MaxTokens *int `json:"max_tokens,omitempty"` + ModelId *string `json:"model_id,omitempty"` + Prompt string `json:"prompt"` + SystemMsg *string `json:"system_msg,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` +} + // BodyUpscaleUpscalePost defines model for Body_upscale_upscale_post. type BodyUpscaleUpscalePost struct { Image openapi_types.File `json:"image"` @@ -96,6 +105,12 @@ type ImageResponse struct { Images []Media `json:"images"` } +// LlmResponse defines model for LlmResponse. +type LlmResponse struct { + Response string `json:"response"` + TokensUsed int `json:"tokens_used"` +} + // Media defines model for Media. type Media struct { Nsfw bool `json:"nsfw"` @@ -161,6 +176,9 @@ type ImageToImageMultipartRequestBody = BodyImageToImageImageToImagePost // ImageToVideoMultipartRequestBody defines body for ImageToVideo for multipart/form-data ContentType. type ImageToVideoMultipartRequestBody = BodyImageToVideoImageToVideoPost +// LlmGenerateFormdataRequestBody defines body for LlmGenerate for application/x-www-form-urlencoded ContentType. +type LlmGenerateFormdataRequestBody = BodyLlmGenerateLlmGeneratePost + // TextToImageJSONRequestBody defines body for TextToImage for application/json ContentType. type TextToImageJSONRequestBody = TextToImageParams @@ -314,6 +332,11 @@ type ClientInterface interface { // ImageToVideoWithBody request with any body ImageToVideoWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + // LlmGenerateWithBody request with any body + LlmGenerateWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + + LlmGenerateWithFormdataBody(ctx context.Context, body LlmGenerateFormdataRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) + // TextToImageWithBody request with any body TextToImageWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) @@ -371,6 +394,30 @@ func (c *Client) ImageToVideoWithBody(ctx context.Context, contentType string, b return c.Client.Do(req) } +func (c *Client) LlmGenerateWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewLlmGenerateRequestWithBody(c.Server, contentType, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) LlmGenerateWithFormdataBody(ctx context.Context, body LlmGenerateFormdataRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewLlmGenerateRequestWithFormdataBody(c.Server, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + func (c *Client) TextToImageWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { req, err := NewTextToImageRequestWithBody(c.Server, contentType, body) if err != nil { @@ -521,6 +568,46 @@ func NewImageToVideoRequestWithBody(server string, contentType string, body io.R return req, nil } +// NewLlmGenerateRequestWithFormdataBody calls the generic LlmGenerate builder with application/x-www-form-urlencoded body +func 
NewLlmGenerateRequestWithFormdataBody(server string, body LlmGenerateFormdataRequestBody) (*http.Request, error) { + var bodyReader io.Reader + bodyStr, err := runtime.MarshalForm(body, nil) + if err != nil { + return nil, err + } + bodyReader = strings.NewReader(bodyStr.Encode()) + return NewLlmGenerateRequestWithBody(server, "application/x-www-form-urlencoded", bodyReader) +} + +// NewLlmGenerateRequestWithBody generates requests for LlmGenerate with any type of body +func NewLlmGenerateRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/llm-generate") + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", queryURL.String(), body) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", contentType) + + return req, nil +} + // NewTextToImageRequest calls the generic TextToImage builder with application/json body func NewTextToImageRequest(server string, body TextToImageJSONRequestBody) (*http.Request, error) { var bodyReader io.Reader @@ -645,6 +732,11 @@ type ClientWithResponsesInterface interface { // ImageToVideoWithBodyWithResponse request with any body ImageToVideoWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*ImageToVideoResponse, error) + // LlmGenerateWithBodyWithResponse request with any body + LlmGenerateWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*LlmGenerateResponse, error) + + LlmGenerateWithFormdataBodyWithResponse(ctx context.Context, body LlmGenerateFormdataRequestBody, reqEditors ...RequestEditorFn) (*LlmGenerateResponse, error) + // TextToImageWithBodyWithResponse request with any body TextToImageWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*TextToImageResponse, error) @@ -755,6 +847,32 @@ func (r ImageToVideoResponse) StatusCode() int { return 0 } +type LlmGenerateResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *LlmResponse + JSON400 *HTTPError + JSON401 *HTTPError + JSON422 *HTTPValidationError + JSON500 *HTTPError +} + +// Status returns HTTPResponse.Status +func (r LlmGenerateResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r LlmGenerateResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + type TextToImageResponse struct { Body []byte HTTPResponse *http.Response @@ -843,6 +961,23 @@ func (c *ClientWithResponses) ImageToVideoWithBodyWithResponse(ctx context.Conte return ParseImageToVideoResponse(rsp) } +// LlmGenerateWithBodyWithResponse request with arbitrary body returning *LlmGenerateResponse +func (c *ClientWithResponses) LlmGenerateWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*LlmGenerateResponse, error) { + rsp, err := c.LlmGenerateWithBody(ctx, contentType, body, reqEditors...) 
+ if err != nil { + return nil, err + } + return ParseLlmGenerateResponse(rsp) +} + +func (c *ClientWithResponses) LlmGenerateWithFormdataBodyWithResponse(ctx context.Context, body LlmGenerateFormdataRequestBody, reqEditors ...RequestEditorFn) (*LlmGenerateResponse, error) { + rsp, err := c.LlmGenerateWithFormdataBody(ctx, body, reqEditors...) + if err != nil { + return nil, err + } + return ParseLlmGenerateResponse(rsp) +} + // TextToImageWithBodyWithResponse request with arbitrary body returning *TextToImageResponse func (c *ClientWithResponses) TextToImageWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*TextToImageResponse, error) { rsp, err := c.TextToImageWithBody(ctx, contentType, body, reqEditors...) @@ -1064,6 +1199,60 @@ func ParseImageToVideoResponse(rsp *http.Response) (*ImageToVideoResponse, error return response, nil } +// ParseLlmGenerateResponse parses an HTTP response from a LlmGenerateWithResponse call +func ParseLlmGenerateResponse(rsp *http.Response) (*LlmGenerateResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &LlmGenerateResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest LlmResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON400 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 422: + var dest HTTPValidationError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON422 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 500: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON500 = &dest + + } + + return response, nil +} + // ParseTextToImageResponse parses an HTTP response from a TextToImageWithResponse call func ParseTextToImageResponse(rsp *http.Response) (*TextToImageResponse, error) { bodyBytes, err := io.ReadAll(rsp.Body) @@ -1186,6 +1375,9 @@ type ServerInterface interface { // Image To Video // (POST /image-to-video) ImageToVideo(w http.ResponseWriter, r *http.Request) + // Llm Generate + // (POST /llm-generate) + LlmGenerate(w http.ResponseWriter, r *http.Request) // Text To Image // (POST /text-to-image) TextToImage(w http.ResponseWriter, r *http.Request) @@ -1222,6 +1414,12 @@ func (_ Unimplemented) ImageToVideo(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotImplemented) } +// Llm Generate +// (POST /llm-generate) +func (_ Unimplemented) LlmGenerate(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotImplemented) +} + // Text To Image // (POST /text-to-image) func (_ Unimplemented) TextToImage(w http.ResponseWriter, r *http.Request) { @@ -1309,6 +1507,23 @@ func (siw *ServerInterfaceWrapper) ImageToVideo(w http.ResponseWriter, r *http.R 
handler.ServeHTTP(w, r.WithContext(ctx)) } +// LlmGenerate operation middleware +func (siw *ServerInterfaceWrapper) LlmGenerate(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + ctx = context.WithValue(ctx, HTTPBearerScopes, []string{}) + + handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + siw.Handler.LlmGenerate(w, r) + })) + + for _, middleware := range siw.HandlerMiddlewares { + handler = middleware(handler) + } + + handler.ServeHTTP(w, r.WithContext(ctx)) +} + // TextToImage operation middleware func (siw *ServerInterfaceWrapper) TextToImage(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -1468,6 +1683,9 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl r.Group(func(r chi.Router) { r.Post(options.BaseURL+"/image-to-video", wrapper.ImageToVideo) }) + r.Group(func(r chi.Router) { + r.Post(options.BaseURL+"/llm-generate", wrapper.LlmGenerate) + }) r.Group(func(r chi.Router) { r.Post(options.BaseURL+"/text-to-image", wrapper.TextToImage) }) @@ -1481,32 +1699,34 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xZ227bOBN+FYL/f+nEhzabhe+SbLcNtoegdrsXRWAw0thmK5FaHtJ6A7/7gkNZomQp", - "cpDGC2R9Zcsaznxz+IZD+o5GMs2kAGE0Hd9RHS0hZfj17OrylVJSue+ZkhkowwHfpHrhPgw3CdAxfacX", - "tEfNKnMP2iguFnS97lEFf1muIKbjL7jkulcsKXQX6+TNV4gMXffouYxXM2ZjLmdGzgz8MLWnTGqzDQpl", - "3Je5VCkzdExvuGBqRQOrKLIFtUdTGUMy47FbHsOc2cStD1a+cwLkMu7006MIPN3Nm7Yw8JQtwIn6L7XH", - "5kAsLI+ZiGCmI+YgBC6dHp+UyF7ncmSCcgUEYdMbUA4CWrk/pJco0hBSj/AeLMMQC6oh3YgekageFbBg", - "ht/CLFMyzUyrjve5HLnyck2qbOpzoGcZqCaFw0CfTQk6qMkVqC2tXBhYePdQrZiDAoyZgUxXlQ4GNbUb", - "YTJB4SalJbjNyna/NJuDWc2iJUTfKpaNslCanqAYuUCxQs2NlAkwgXoA4tDixD03gdNGgViYZcXY4PjX", - "wNZGYqscatTLNl75sq1zcAcqdbLwlscg64/NLJzXUvdLCef3lkQtgS+W1So6OQ3WvfHvm5Y+hqmP4lQq", - "DZdidmOjb2DqSoaj01CLkyTnKFnRFhJAcg0zZhezlsIYjAICOGFyZhekvUa6OTU6eTil9k6T7zyuhWI4", - "GL0sLf2J77dX1ijSwYz28m5jhs2wsRefzVz416qzK/enJ8+qnT6sITbmriHRb6bTq5ZBMAbDeOK+/V/B", - "nI7p//rlONnPZ8l+MezVAebLA2ClrRYgn1nCY+Y6SSckbiDVXdjq+tYllt+8pgIIU4qt0IcQbV1BE25g", - "iVlebIqgilcbZmy1KumHP2i4/6FA0+BZbgylgQb7yK2PoDMpNLSwU+8csXcQcxbGyY82TXHaaj06zHUV", - "VgNub2kLr9Dz7yEZ3rvnR3VXq5JQ7pNKOud8izLaa0REgWceeINHU/hh2hMRLa34tnsiUDxMxIVfX09E", - "j7pzRuigg9HpofFCOajAu4oTLU5OJWb3iinmHXmqI0o5M+0wJf3HTw8nz+3wUExFDxyDcqdqNV2t2YbC", - "7tx7EhlV2MvE6sOcjr/cbcXqbgvidUDktzJCMw1Url+9gNYtg5P/oRRFzGTqfu2ivvPDm8olg0jtsN99", - "dnNje5ubK5bW9psHbjz19rY5V3nFHRtRbj50qYK3wSHfabcc2a2tOjspaMPSLHQ1wD0t3ndAN6GgMxY4", - "4TFugUc6RVZxs5q4OHrkbnA5B6ZAFXd+yEH/U6FkaUxG104HF3PpKa0jxTMszjE9E4RlWcJ9tRIjibKC", - "nF2SjGeQcOGTsSlqfgsZgHLvP1oh0NAtKO11DY6HxwMXLZmBYBmnY/oCf+rRjJklwu7jzdmRkUeb0G/O", - "Gy4tCOIy3tzzTWWeDxdB0MbNvLjLSmFA4KrUJoZnTJm+O5gcxcyw8g60qxx3u9hbV3PoOiH+4IsNvRoN", - "BjVcQVD7X7ULz66gKnsz2q5mbGKjCLSe24SUYj368idCKEf4BvvnLCYffT683eF+7H4SzJqlVPxviNHw", - "8MV+DOfOklfCcLMiUynJW6YWPuqj0U8FsXWW2YZTipDivHOyr+RfCgNKsIRMQN2CIuWhcNOicK8Mm9OX", - "6/V1j2qbpkytNswmU0mQ225pf4mHH5wqoaEX+LMRfULOhaevXSm3Dp3KIaI3OBa6DlfcmTS3OBxV8onl", - "iXvcDhene+5y1ZPjoc21t7lDh3loh/H/RE2lP3PVSIk3op2kxHlyX6Rsv7PdMymrU/SBlAdSPgEpPbWQ", - "lG7G3mGjDE7291LycTN39e7gsB0emPdMmOeKu7Yb5v8XtVPuUy7wtDtg499XB+YdmPdMmLdh0dqvcmo0", - "LqpaKq7VLhJpY3Ih09QKblbkNTPwna1o/vcWXubpcb8fK2Dp0cK/PU7y5ceRW07X1+t/AgAA///2pVcb", - "EigAAA==", + "H4sIAAAAAAAC/+xZW2/bOhL+KwR3H53YSZvNwm9JttsGm7ZB7XQfikBgpLHNRiK1vCTxBv7vByR1oW6V", + 
"jDQ+QI6fbFHDmW/I+YYz1DMOeZJyBkxJPH3GMlxBQuzfs+vLD0JwYf6ngqcgFAX7JpFL86OoigFP8We5", + "xCOs1ql5kEpQtsSbzQgL+J+mAiI8/WGn3I6KKYXuYh6/+wmhwpsRPufROiA6ojxQPFDwpGpPKZeqCcrK", + "mD8LLhKi8BTfUUbEGntWrUgD6ggnPII4oJGZHsGC6NjM92Z+NgLoMur106HwPB3mTdcy0IQswYi6P7XH", + "9oVYahoRFkIgQ2IgeC6dHp6UyD5mcmhm5QoITCd3IAwEa+XXS3ppRVqW1CH8BZYjH4tVg/oRvWCjRpjB", + "kij6AEEqeJKqTh1fMjl07eTaVOnE7YEMUhBtCo88fTpB1kGJrkE0tFKmYOncs2rZAgTYNVOQyqrSyaSm", + "NhdGMyvcprQEl8/s9kuSBah1EK4gvK9YVkJDaXpmxdCFFSvU3HEeA2FWD0DkW5yZ5zZwUglgS7WqGJsc", + "/tOzlUs0wqFGvTT3yoVtnYMDqNTLwgcaAa8/trNwUdu6f5Rw/t2xUSugy1U1ik5OvXmf3Pu2qS9h6os4", + "lXBFOQvudHgPqq7k6PjU12Ik0bmVrGjzCcCphIDoZdARGJNjjwBGGJ3pJeqOkX5OHZ9sT6md0+SRRrWl", + "OJocvy8t/de+b86sUaSHGd3h3cWMOE6CJTAQREH1oZ0VCXkKFL8HJisFBHlCczfa5vyLAnSr7LeWCpKg", + "Vt7M7ChqrXJGWEGSGo+1AH/S3BsemLjq29Kztl1bolN71ha/7RvxpyWMPjqenrypE267M6p171o2+tN8", + "ft1Rm0egCI3Nv78LWOAp/tu4rPDHWXk/LurvOsBsugestNUB5DuJaURMcu+FRBUksg9bXd+mxPIvp6kA", + "QoQga+uDj7auoA03kFitLvIgqOKViihdjUr89T/YL0msQFsvUJ7VpYEW+5Zb30CmnEnoYKccvGKfIaLE", + "XydXbbatU+M0kP5eV2G14L6Kk27UwnuTa2wq81KnTfmBllUmuZMA3cghhBKefk+d55MPucUjt3YNX5hc", + "PPqgvpjnFx3hWsS+3I2Ie5tJbWWk02gReX454C0ezeFJdW9SuNLsfnhoWXE/tC7c/HpomYPwSVVPwCfV", + "66FyQhkoz7uKEx1OzrmN12siiHPktfrgsjAfUIr/xVvUk7fWoRal95a1drOqa8ZsS2D3nqYxDyvsJWz9", + "dYGnP54ba/XcgHjrEfmKh9ZMC5Xr93sgZUcp6AZKUYsZzc1oH/WNH85UJumt1IAT/LtpTrrT3EKQpHaC", + "bnmU1tNb3rw7xT1Ha2bed6mCt8Uhl2kbjgxLq8ZOAlKRJPVd9XDPi/c90JUvaIx5TjiMDfCWTqEWVK1n", + "Zh0dclOKnQMRIIqLZctBN1QoWSmV4o3RQdmCO0rLUNDUBucUnzFE0jSmLlqR4khohs4uUUpTiClzm5EH", + "NX2AFECY9980Y9bQAwjpdE0Ojw4nZrV4CoykFE/xOzs0wilRKwt7bK9nDxQ/yJc+76C47egoZ5dRfpk8", + "59l+mBUEqUwVb09ZzhQwOyvRsaIpEWpsWq2DiChSXrT3heOw2+NNdQ9NJrQDLtisV8eTSQ2Xt6jjn9Is", + "z1BQlbPZ2q7u2EyHIUi50DEqxUb4/W+EUDYlLfbPSYS+uf1wdo92Y/eGEa1WXND/Q2QNH73bjeHMWfSB", + "KarWaM45uiJi6Vb9+Pi3gmh0Z004pQgqOriTXW3+JVMgGInRDMQDCFS2uXmKsmeln5x+3G5uR1jqJCFi", + "nTMbzTmy3DZTxyvbztmqElpygev28Ctyzu8nh1Ju4zuVQbTe2LLQZLjiFqg9xdlSJatYXjnHDbid33GW", + "q/bC+zTXneb2GWbbDOM+d86567lqpLTX7r2ktPXkrkjZ/WFgx6SsVtF7Uu5J+QqkdNSypIzj5CD/7NJN", + "yas4+ZgL/YqRvu9PB4+PjweWmVrEwEIeuQuJLfjZ84Vox9z0L1r3zNwz8/cx8ypOUEEwy0vT+w4oYL0b", + "t8HE3L4Xrt7p7cvUPe/eCO9McNeq1OzLdDflbjKB161MWz+U75m3Z94bYV7Ooo2bZdRIO6lqqbjuvoi5", + "jtAFTxLNqFqjj0TBI1nj7LOzvWSX0/E4EkCSg6V7exhn0w9DMx1vbjd/BAAA///6mbgGDy4AAA==", } // GetSwagger returns the content of the embedded swagger specification file diff --git a/worker/worker.go b/worker/worker.go index 7877f6dd..0dadceb9 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -304,6 +304,54 @@ func (w *Worker) AudioToText(ctx context.Context, req AudioToTextMultipartReques return resp.JSON200, nil } +func (w *Worker) LlmGenerate(ctx context.Context, req BodyLlmGenerateLlmGeneratePost) (*LlmResponse, error) { + c, err := w.borrowContainer(ctx, "llm-generate", *req.ModelId) + if err != nil { + return nil, err + } + defer w.returnContainer(c) + + var buf bytes.Buffer + mw, err := NewLlmGenerateMultipartWriter(&buf, req) + if err != nil { + return nil, err + } + + resp, err := c.Client.LlmGenerateWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf) + if err != nil { + return nil, err + } + + if resp.JSON400 != nil { + val, err := json.Marshal(resp.JSON400) + if err != nil { + return nil, err + } + slog.Error("llm-generate container returned 400", slog.String("err", string(val))) + return nil, errors.New("llm-generate container returned 400") + } + + if resp.JSON401 != nil { + val, err := json.Marshal(resp.JSON401) + if err != nil { + return nil, err + } + slog.Error("llm-generate 
container returned 401", slog.String("err", string(val))) + return nil, errors.New("llm-generate container returned 401") + } + + if resp.JSON500 != nil { + val, err := json.Marshal(resp.JSON500) + if err != nil { + return nil, err + } + slog.Error("llm-generate container returned 500", slog.String("err", string(val))) + return nil, errors.New("llm-generate container returned 500") + } + + return resp.JSON200, nil +} + func (w *Worker) Warm(ctx context.Context, pipeline string, modelID string, endpoint RunnerEndpoint, optimizationFlags OptimizationFlags) error { if endpoint.URL == "" { return w.manager.Warm(ctx, pipeline, modelID, optimizationFlags) From 0619926c76968c912520dd737ba1fff58efe4601 Mon Sep 17 00:00:00 2001 From: kyriediculous Date: Wed, 31 Jul 2024 13:19:21 +0200 Subject: [PATCH 04/10] update transformers --- runner/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runner/requirements.txt b/runner/requirements.txt index 21de29da..c0ae34b1 100644 --- a/runner/requirements.txt +++ b/runner/requirements.txt @@ -1,6 +1,6 @@ diffusers==0.29.2 accelerate==0.30.1 -transformers==4.41.1 +transformers==4.43.3 fastapi==0.111.0 pydantic==2.7.2 Pillow==10.3.0 From 922f9d2c99d877da2da73ede0ef9305f3ed5bdf1 Mon Sep 17 00:00:00 2001 From: kyriediculous Date: Wed, 31 Jul 2024 14:40:39 +0200 Subject: [PATCH 05/10] llm: support streamed responses --- runner/app/pipelines/llm_generate.py | 226 +++++++++++++-------------- runner/app/routes/llm_generate.py | 68 ++++++-- runner/check_torch_cuda.py | 18 +++ runner/openapi.json | 19 ++- worker/docker.go | 2 +- worker/runner.gen.go | 59 +++---- worker/worker.go | 136 ++++++++++++---- 7 files changed, 341 insertions(+), 187 deletions(-) create mode 100644 runner/check_torch_cuda.py diff --git a/runner/app/pipelines/llm_generate.py b/runner/app/pipelines/llm_generate.py index 0f70e1b2..93025e0a 100644 --- a/runner/app/pipelines/llm_generate.py +++ b/runner/app/pipelines/llm_generate.py @@ -1,141 +1,133 @@ +import asyncio import logging import os -from typing import Dict, Any, Optional +from typing import Dict, Any, List, Optional import torch -from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline +from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer from app.pipelines.base import Pipeline from app.pipelines.utils import get_model_dir, get_torch_device from huggingface_hub import file_download, hf_hub_download +from threading import Thread +from typing import AsyncGenerator, Union, Dict, Any, Optional, List logger = logging.getLogger(__name__) -# class LLMGeneratePipeline(Pipeline): -# def __init__(self, model_id: str): -# self.model_id = model_id -# kwargs = { -# "cache_dir": get_model_dir() -# } -# self.device = get_torch_device() -# folder_name = file_download.repo_folder_name( -# repo_id=model_id, repo_type="model" -# ) -# folder_path = os.path.join(get_model_dir(), folder_name) - -# # Check for fp16 variant -# has_fp16_variant = any( -# ".fp16.safetensors" in fname -# for _, _, files in os.walk(folder_path) -# for fname in files -# ) -# if self.device != "cpu" and has_fp16_variant: -# logger.info("LLMGeneratePipeline loading fp16 variant for %s", model_id) -# kwargs["torch_dtype"] = torch.float16 -# kwargs["variant"] = "fp16" - -# # Load tokenizer -# self.tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs) - -# # Load model -# self.model = AutoModelForCausalLM.from_pretrained( -# model_id, **kwargs).to(self.device) - -# # Set up generation config -# 
self.generation_config = self.model.generation_config - -# # Optional: Add optimizations -# sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true" -# if sfast_enabled: -# logger.info( -# "LLMGeneratePipeline will be dynamically compiled with stable-fast for %s", -# model_id, -# ) -# from app.pipelines.optim.sfast import compile_model -# self.model = compile_model(self.model) - -# def __call__(self, prompt: str, system_msg: Optional[str] = None, -# temperature: Optional[float] = None, -# max_tokens: Optional[int] = None, **kwargs) -> Dict[str, Any]: -# if system_msg: -# input_text = f"{system_msg}\n\n{prompt}" -# else: -# input_text = prompt - -# input_ids = self.tokenizer.encode( -# input_text, return_tensors="pt").to(self.device) - -# # Update generation config -# gen_kwargs = {} -# if temperature is not None: -# gen_kwargs['temperature'] = temperature -# if max_tokens is not None: -# gen_kwargs['max_new_tokens'] = max_tokens - -# # Merge generation config with provided kwargs -# gen_kwargs = {**self.generation_config.to_dict(), **gen_kwargs, **kwargs} - -# # Generate response -# with torch.no_grad(): -# output = self.model.generate( -# input_ids, -# **gen_kwargs -# ) - -# # Decode the response -# response = self.tokenizer.decode(output[0], skip_special_tokens=True) - -# # Calculate tokens used -# tokens_used = len(output[0]) - -# return { -# "response": response.strip(), -# "tokens_used": tokens_used -# } - -# def __str__(self) -> str: -# return f"LLMPipeline model_id={self.model_id}" - - class LLMGeneratePipeline(Pipeline): def __init__(self, model_id: str): self.model_id = model_id - self.device = get_torch_device() - kwargs = { - "cache_dir": get_model_dir(), - "device_map": "auto", - "torch_dtype": torch.bfloat16 if self.device != "cpu" else torch.float32, + "cache_dir": get_model_dir() } - - logger.info(f"Loading model {model_id}") - self.pipeline = pipeline( - "text-generation", - model=model_id, - tokenizer=model_id, - **kwargs + self.device = get_torch_device() + folder_name = file_download.repo_folder_name( + repo_id=model_id, repo_type="model" ) + folder_path = os.path.join(get_model_dir(), folder_name) - def __call__(self, prompt: str, system_msg: str = None, **kwargs): - messages = [] - if system_msg: - messages.append({"role": "system", "content": system_msg}) - messages.append({"role": "user", "content": prompt}) - - outputs = self.pipeline( - messages, - max_new_tokens=kwargs.get("max_tokens", 256), - temperature=kwargs.get("temperature", 0.7), + # Check for fp16 variant + has_fp16_variant = any( + ".fp16.safetensors" in fname + for _, _, files in os.walk(folder_path) + for fname in files ) + if self.device != "cpu" and has_fp16_variant: + logger.info("LLMGeneratePipeline loading fp16 variant for %s", model_id) + kwargs["torch_dtype"] = torch.float16 + kwargs["variant"] = "fp16" + elif self.device != "cpu": + kwargs["torch_dtype"] = torch.bfloat16 - response = outputs[0]["generated_text"] - # Assuming the response is the last message in the conversation - response = response.split("assistant:")[-1].strip() + # Add device mapping + kwargs["device_map"] = "auto" - return { - "response": response, - "tokens_used": len(self.pipeline.tokenizer.encode(response)) - } + logger.info(f"Loading model {model_id}") + self.tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs) + self.model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs) + + # Set up generation config + self.generation_config = self.model.generation_config + + self.terminators = [ + 
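+            # <|eot_id|> marks the end of an assistant turn for Llama 3 chat models,
+            # so it is tracked here alongside the tokenizer's regular EOS token.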
self.tokenizer.eos_token_id, + self.tokenizer.convert_tokens_to_ids("<|eot_id|>") + ] + + # Optional: Add optimizations + sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true" + if sfast_enabled: + logger.info( + "LLMGeneratePipeline will be dynamically compiled with stable-fast for %s", + model_id, + ) + from app.pipelines.optim.sfast import compile_model + self.model = compile_model(self.model) + + async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, system_msg: Optional[str] = None, **kwargs) -> AsyncGenerator[Union[str, Dict[str, Any]], None]: + conversation = [] + if system_msg: + conversation.append({"role": "system", "content": system_msg}) + if history: + for user, assistant in history: + conversation.extend([{"role": "user", "content": user}, { + "role": "assistant", "content": assistant}]) + conversation.append({"role": "user", "content": prompt}) + + input_ids = self.tokenizer.apply_chat_template( + conversation, return_tensors="pt").to(self.model.device) + attention_mask = torch.ones_like(input_ids) + + max_new_tokens = kwargs.get("max_tokens", 256) + temperature = kwargs.get("temperature", 0.7) + + streamer = TextIteratorStreamer( + self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) + + # Start with the generation config + generate_kwargs = self.generation_config.to_dict() + # Update with our specific parameters + generate_kwargs.update({ + "input_ids": input_ids, + "attention_mask": attention_mask, + "streamer": streamer, + "max_new_tokens": max_new_tokens, + "do_sample": True, + "temperature": temperature, + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.eos_token_id, + }) + + # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash. 
+ if temperature == 0: + generate_kwargs['do_sample'] = False + + # Start generation in a separate thread + thread = Thread(target=self.model_generate_wrapper, kwargs=generate_kwargs) + thread.start() + + total_tokens = 0 + try: + for text in streamer: + total_tokens += 1 + yield text + await asyncio.sleep(0) # Allow other tasks to run + except Exception as e: + logger.error(f"Error during streaming: {str(e)}") + raise + + input_length = input_ids.size(1) + yield {"tokens_used": input_length + total_tokens} def __str__(self): return f"LLMGeneratePipeline(model_id={self.model_id})" + + def model_generate_wrapper(self, **kwargs): + try: + logger.debug("Entering model.generate") + with torch.cuda.amp.autocast(): # Use automatic mixed precision + self.model.generate(**kwargs) + logger.debug("Exiting model.generate") + except Exception as e: + logger.error(f"Error in model.generate: {str(e)}", exc_info=True) + raise diff --git a/runner/app/routes/llm_generate.py b/runner/app/routes/llm_generate.py index d43120d3..07e70ff8 100644 --- a/runner/app/routes/llm_generate.py +++ b/runner/app/routes/llm_generate.py @@ -1,32 +1,37 @@ import logging import os -from typing import Annotated, Optional -from fastapi import APIRouter, Depends, Form, status -from fastapi.responses import JSONResponse +from typing import Annotated, Optional, List +from fastapi import APIRouter, Depends, Form, status, Request +from fastapi.responses import JSONResponse, StreamingResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from app.dependencies import get_pipeline from app.pipelines.base import Pipeline from app.routes.util import HTTPError, LlmResponse, TextResponse, http_error +import json router = APIRouter() logger = logging.getLogger(__name__) RESPONSES = { + status.HTTP_200_OK: {"model": LlmResponse}, status.HTTP_400_BAD_REQUEST: {"model": HTTPError}, status.HTTP_401_UNAUTHORIZED: {"model": HTTPError}, status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError}, } -@router.post("/llm-generate", response_model=LlmResponse, responses=RESPONSES) +@router.post("/llm-generate", + response_model=LlmResponse, responses=RESPONSES) @router.post("/llm-generate/", response_model=LlmResponse, responses=RESPONSES, include_in_schema=False) async def llm_generate( prompt: Annotated[str, Form()], model_id: Annotated[str, Form()] = "", - system_msg: Annotated[str, Form()] = None, - temperature: Annotated[float, Form()] = None, - max_tokens: Annotated[int, Form()] = None, + system_msg: Annotated[str, Form()] = "", + temperature: Annotated[float, Form()] = 0.7, + max_tokens: Annotated[int, Form()] = 256, + history: Annotated[str, Form()] = "[]", # We'll parse this as JSON + stream: Annotated[bool, Form()] = False, pipeline: Pipeline = Depends(get_pipeline), token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), ): @@ -49,16 +54,57 @@ async def llm_generate( ) try: - result = pipeline( + history_list = json.loads(history) + if not isinstance(history_list, list): + raise ValueError("History must be a JSON array") + + generator = pipeline( prompt=prompt, - system_msg=system_msg, + history=history_list, + system_msg=system_msg if system_msg else None, temperature=temperature, max_tokens=max_tokens ) - return JSONResponse(content=result) + + if stream: + return StreamingResponse(stream_generator(generator), media_type="text/event-stream") + else: + full_response = "" + async for chunk in generator: + if isinstance(chunk, dict): + tokens_used = chunk["tokens_used"] + break + full_response += 
chunk + + return LlmResponse(response=full_response, tokens_used=tokens_used) + + except json.JSONDecodeError: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": "Invalid JSON format for history"} + ) + except ValueError as ve: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(ve)} + ) except Exception as e: logger.error(f"LLM processing error: {str(e)}") return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=http_error("Internal server error during LLM processing."), + content={"detail": "Internal server error during LLM processing."} ) + + +async def stream_generator(generator): + try: + async for chunk in generator: + if isinstance(chunk, dict): # This is the final result + yield f"data: {json.dumps(chunk)}\n\n" + break + else: + yield f"data: {json.dumps({'chunk': chunk})}\n\n" + yield "data: [DONE]\n\n" + except Exception as e: + logger.error(f"Streaming error: {str(e)}") + yield f"data: {json.dumps({'error': str(e)})}\n\n" diff --git a/runner/check_torch_cuda.py b/runner/check_torch_cuda.py new file mode 100644 index 00000000..eaa297a1 --- /dev/null +++ b/runner/check_torch_cuda.py @@ -0,0 +1,18 @@ +import torch +import subprocess + +print(f"PyTorch version: {torch.__version__}") +print(f"CUDA available: {torch.cuda.is_available()}") +if torch.cuda.is_available(): + print(f"CUDA version: {torch.version.cuda}") + +# Check system CUDA version +try: + nvcc_output = subprocess.check_output(["nvcc", "--version"]).decode("utf-8") + cuda_version = nvcc_output.split("release ")[-1].split(",")[0] + print(f"System CUDA version: {cuda_version}") +except: + print("Unable to check system CUDA version") + +# Print the current device +print(f"Current device: {torch.cuda.get_device_name(0)}") diff --git a/runner/openapi.json b/runner/openapi.json index 83ca1bff..982dd018 100644 --- a/runner/openapi.json +++ b/runner/openapi.json @@ -647,15 +647,28 @@ }, "system_msg": { "type": "string", - "title": "System Msg" + "title": "System Msg", + "default": "" }, "temperature": { "type": "number", - "title": "Temperature" + "title": "Temperature", + "default": 0.7 }, "max_tokens": { "type": "integer", - "title": "Max Tokens" + "title": "Max Tokens", + "default": 256 + }, + "history": { + "type": "string", + "title": "History", + "default": "[]" + }, + "stream": { + "type": "boolean", + "title": "Stream", + "default": false } }, "type": "object", diff --git a/worker/docker.go b/worker/docker.go index ce510493..31e9a5d3 100644 --- a/worker/docker.go +++ b/worker/docker.go @@ -35,7 +35,7 @@ var containerHostPorts = map[string]string{ "image-to-video": "8002", "upscale": "8003", "audio-to-text": "8004", - "llm": "8005", + "llm-generate": "8005", } type DockerManager struct { diff --git a/worker/runner.gen.go b/worker/runner.gen.go index 61c1aa8a..e93fb2da 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -68,9 +68,11 @@ type BodyImageToVideoImageToVideoPost struct { // BodyLlmGenerateLlmGeneratePost defines model for Body_llm_generate_llm_generate_post. 
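+// History carries a JSON-encoded array of prior [user, assistant] turns, and Stream
+// selects the runner's server-sent-events response path.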
type BodyLlmGenerateLlmGeneratePost struct { + History *string `json:"history,omitempty"` MaxTokens *int `json:"max_tokens,omitempty"` ModelId *string `json:"model_id,omitempty"` Prompt string `json:"prompt"` + Stream *bool `json:"stream,omitempty"` SystemMsg *string `json:"system_msg,omitempty"` Temperature *float32 `json:"temperature,omitempty"` } @@ -1699,34 +1701,35 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xZW2/bOhL+KwR3H53YSZvNwm9JttsGm7ZB7XQfikBgpLHNRiK1vCTxBv7vByR1oW6V", - "jDQ+QI6fbFHDmW/I+YYz1DMOeZJyBkxJPH3GMlxBQuzfs+vLD0JwYf6ngqcgFAX7JpFL86OoigFP8We5", - "xCOs1ql5kEpQtsSbzQgL+J+mAiI8/WGn3I6KKYXuYh6/+wmhwpsRPufROiA6ojxQPFDwpGpPKZeqCcrK", - "mD8LLhKi8BTfUUbEGntWrUgD6ggnPII4oJGZHsGC6NjM92Z+NgLoMur106HwPB3mTdcy0IQswYi6P7XH", - "9oVYahoRFkIgQ2IgeC6dHp6UyD5mcmhm5QoITCd3IAwEa+XXS3ppRVqW1CH8BZYjH4tVg/oRvWCjRpjB", - "kij6AEEqeJKqTh1fMjl07eTaVOnE7YEMUhBtCo88fTpB1kGJrkE0tFKmYOncs2rZAgTYNVOQyqrSyaSm", - "NhdGMyvcprQEl8/s9kuSBah1EK4gvK9YVkJDaXpmxdCFFSvU3HEeA2FWD0DkW5yZ5zZwUglgS7WqGJsc", - "/tOzlUs0wqFGvTT3yoVtnYMDqNTLwgcaAa8/trNwUdu6f5Rw/t2xUSugy1U1ik5OvXmf3Pu2qS9h6os4", - "lXBFOQvudHgPqq7k6PjU12Ik0bmVrGjzCcCphIDoZdARGJNjjwBGGJ3pJeqOkX5OHZ9sT6md0+SRRrWl", - "OJocvy8t/de+b86sUaSHGd3h3cWMOE6CJTAQREH1oZ0VCXkKFL8HJisFBHlCczfa5vyLAnSr7LeWCpKg", - "Vt7M7ChqrXJGWEGSGo+1AH/S3BsemLjq29Kztl1bolN71ha/7RvxpyWMPjqenrypE267M6p171o2+tN8", - "ft1Rm0egCI3Nv78LWOAp/tu4rPDHWXk/LurvOsBsugestNUB5DuJaURMcu+FRBUksg9bXd+mxPIvp6kA", - "QoQga+uDj7auoA03kFitLvIgqOKViihdjUr89T/YL0msQFsvUJ7VpYEW+5Zb30CmnEnoYKccvGKfIaLE", - "XydXbbatU+M0kP5eV2G14L6Kk27UwnuTa2wq81KnTfmBllUmuZMA3cghhBKefk+d55MPucUjt3YNX5hc", - "PPqgvpjnFx3hWsS+3I2Ie5tJbWWk02gReX454C0ezeFJdW9SuNLsfnhoWXE/tC7c/HpomYPwSVVPwCfV", - "66FyQhkoz7uKEx1OzrmN12siiHPktfrgsjAfUIr/xVvUk7fWoRal95a1drOqa8ZsS2D3nqYxDyvsJWz9", - "dYGnP54ba/XcgHjrEfmKh9ZMC5Xr93sgZUcp6AZKUYsZzc1oH/WNH85UJumt1IAT/LtpTrrT3EKQpHaC", - "bnmU1tNb3rw7xT1Ha2bed6mCt8Uhl2kbjgxLq8ZOAlKRJPVd9XDPi/c90JUvaIx5TjiMDfCWTqEWVK1n", - "Zh0dclOKnQMRIIqLZctBN1QoWSmV4o3RQdmCO0rLUNDUBucUnzFE0jSmLlqR4khohs4uUUpTiClzm5EH", - "NX2AFECY9980Y9bQAwjpdE0Ojw4nZrV4CoykFE/xOzs0wilRKwt7bK9nDxQ/yJc+76C47egoZ5dRfpk8", - "59l+mBUEqUwVb09ZzhQwOyvRsaIpEWpsWq2DiChSXrT3heOw2+NNdQ9NJrQDLtisV8eTSQ2Xt6jjn9Is", - "z1BQlbPZ2q7u2EyHIUi50DEqxUb4/W+EUDYlLfbPSYS+uf1wdo92Y/eGEa1WXND/Q2QNH73bjeHMWfSB", - "KarWaM45uiJi6Vb9+Pi3gmh0Z004pQgqOriTXW3+JVMgGInRDMQDCFS2uXmKsmeln5x+3G5uR1jqJCFi", - "nTMbzTmy3DZTxyvbztmqElpygev28Ctyzu8nh1Ju4zuVQbTe2LLQZLjiFqg9xdlSJatYXjnHDbid33GW", - "q/bC+zTXneb2GWbbDOM+d86567lqpLTX7r2ktPXkrkjZ/WFgx6SsVtF7Uu5J+QqkdNSypIzj5CD/7NJN", - "yas4+ZgL/YqRvu9PB4+PjweWmVrEwEIeuQuJLfjZ84Vox9z0L1r3zNwz8/cx8ypOUEEwy0vT+w4oYL0b", - "t8HE3L4Xrt7p7cvUPe/eCO9McNeq1OzLdDflbjKB161MWz+U75m3Z94bYV7Ooo2bZdRIO6lqqbjuvoi5", - "jtAFTxLNqFqjj0TBI1nj7LOzvWSX0/E4EkCSg6V7exhn0w9DMx1vbjd/BAAA///6mbgGDy4AAA==", + "H4sIAAAAAAAC/+xZX1PjOBL/KirdPQYSmOG4yhtwczPUwgw1CbMPFJUSdifRYEteSQayVL77liT/kW0Z", + "O8WQrWLzlNhudf9a6l+rW3rGAY8TzoApicfPWAZLiIn5e3J1/kkILvT/RPAEhKJgvsRyoX8UVRHgMb6U", + "CzzAapXoB6kEZQu8Xg+wgD9SKiDE4xsz5HZQDCl0F+P43U8IFF4P8CkPVzOShpTPFJ8peFK1p4RL1QRl", + "ZPSfORcxUXiM7ygjYoUdq0akAXWAYx5CNKOhHh7CnKSRHu+MvNQC6Dzs9NOicDzt503bNNCYLECL2j+1", + "R/9ELFIaEhbATAZEQ3BcOt4/KpF9zuTQxMgVEFga34HQEIyVl6f03Ih4ptQifAHLgYvFqEHdiF6xUAPM", + "YEEUfYBZInicqFYdXzM5dGXlfKrS2K6BnCUgfAoPHH1pjIyDEl2BaGilTMHCumfUsjkIMHOmIJFVpaNR", + "TW0ujCZG2Ke0BJePbPdLkjmo1SxYQnBfsaxECqXpiRFDZ0asUHPHeQSEGT0AoWtxop994KQSwBZqWTE2", + 
"2v+vYyuXaIRDjXpJ7pUN2zoHe1Cpk4UPNARef/SzcF5buv+UcP7fslBLoItlNYqOjp1xX+x339DXMPVV", + "nIq5opzN7tLgHlRdycHhsatFS6JTI1nR5hKAUwkzki5mLYExOnQIoIXRSbpA7THSzanDo80ptXWaPNKw", + "NhUHo8OPpaXfzffmyBpFOpjRHt5tzIiieLYABoIoqD74WbGkUnGxqobaza0TbF8yCV+skaeZ4vfA6gvo", + "cOSSPKGplfHN46tifaNEqgSQuGJmTiIJ1bxGYn9YrKSCeJaVWV6cEyOCvKXXACuIE70MqYAafxxCTh2h", + "nrm1Hjkdy98WNWliyoHi1x8rf1tO68oYx0fvahPebBv1rp1nob9Mp1ct7UMIitBI//u3gDke438NyyZk", + "mHUgw6JFqAPMhjvASlstQH6QiIZE7z+dkKiCWHZhq+tbl1j+ZzUVQIgQZGV8cNHWFfhwA4nU8iwPgipe", + "qYhKq1GJv/3m5gcr4GtXynKiNOCxb7j1HWTCmYQWdsreM3YJISXuPNmC2DdPjQ1LumtdheXBfRHF7aiF", + "8yXX2FTmJFKzlcxSWWWS3WHQtexDKOHod9Q5PrmQPR7ZuWv4wuT80QX1VT+/qspIReTKXYuos99NjYy0", + "Gg0ixy8L3OPRFJ5U+yIFy5Td9w8tI+6G1pkdXw8tvS0+VVKwhtHpobJCGajbQWX0S8umv0+5idcrIoh1", + "5K1a9bJ36NEt/MO76KP31kQX3cGG7UCzqmvGrCewO3fTiAcV9hK2+jbH45vnxlw9NyDeOkS+4IEx46Fy", + "/QgSpGwpBe2LUtRgRlP9tov62g9rKpN0ZqrHDv5D90/taW4uSFzbQTfcSuvpLT9fsIo7ttbMvOtSBa/H", + "IZtpG470S6vaTgxSkThxXXVwT4vvHdCVK6iNOU5YjA3whk5BKqhaTfQ8WuS6FDsFIkAUZ9+Gg/ZVoWSp", + "VILXWgdlc24pLQNBExOcY3zCEEmSiNpoRYojkTJ0co4SmkBEmV2MPKjpAyQAQn//njJmDD2AkFbXaP9g", + "f6RniyfASELxGH8wrwY4IWppYA/NCfKe4nv51OcdFDcdHeXsPMzPu6c8Ww89gyCVruLNLsuZAmZGxWmk", + "aEKEGupWay8kipR3AV3h2O+Ae11dQ50JzQsbbMarw9GohsuZ1OFPqaenL6jK3mxsV1dskgYBSDlPI1SK", + "DfDHXwihbEo89k9JiL7b9bB2D7Zj95qRVC25oH9CaAwffNiO4cxZ9IkpqlZoyjm6IGJhZ/3w8JeCaHRn", + "TTilCCo6uKNtLf45UyAYidAExAMIVLa5eYoye6WbnG5u17cDLNM4JmKVMxtNOTLc1kOHS9POmaoSPLnA", + "dnv4DTnn9pN9Kbd2ncogGm9MWagzXHEK5E9xplTJKpY3znE9LhC2nOWqvfAuzbWnuV2G2TTD2BvZKbc9", + "V42U5magk5SmntwWKdvvLrZMymoVvSPljpRvQEpLLUPKKIr38muXdkpeRPHnXOglRrq+P+09Pj7uGWam", + "IgIW8NAeSGzAz44boi1z0z1o3TFzx8xfx8yLKEYFwQwvde/bo4B1Ttx6E3PzXrh6prcrU3e8eye808Fd", + "q1Kzm+l2yl1nAm9bmXovynfM2zHvnTAvZ9HajtJqpBlUtVQcd59FPA3RGY/jlFG1Qp+Jgkeywtm1szlk", + "l+PhMBRA4r2F/bofZcP3Az0cr2/XfwUAAP//+w/BrrIuAAA=", } // GetSwagger returns the content of the embedded swagger specification file diff --git a/worker/worker.go b/worker/worker.go index 0dadceb9..41f91610 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -1,11 +1,15 @@ package worker import ( + "bufio" "bytes" "context" "encoding/json" "errors" + "fmt" + "io" "log/slog" + "net/http" "strconv" "sync" ) @@ -304,13 +308,23 @@ func (w *Worker) AudioToText(ctx context.Context, req AudioToTextMultipartReques return resp.JSON200, nil } -func (w *Worker) LlmGenerate(ctx context.Context, req BodyLlmGenerateLlmGeneratePost) (*LlmResponse, error) { +func (w *Worker) LlmGenerate(ctx context.Context, req LlmGenerateFormdataRequestBody) (interface{}, error) { + slog.Info("Incoming request %v", req) c, err := w.borrowContainer(ctx, "llm-generate", *req.ModelId) if err != nil { return nil, err } + if c == nil { + return nil, errors.New("borrowed container is nil") + } + if c.Client == nil { + return nil, errors.New("container client is nil") + } + defer w.returnContainer(c) + slog.Info("Container borrowed successfully", "model_id", *req.ModelId) + var buf bytes.Buffer mw, err := NewLlmGenerateMultipartWriter(&buf, req) if err != nil { @@ -322,34 +336,11 @@ func (w *Worker) LlmGenerate(ctx context.Context, req BodyLlmGenerateLlmGenerate return nil, err } - if resp.JSON400 != nil { - val, err := json.Marshal(resp.JSON400) - if err != nil { - return nil, err - } - slog.Error("llm-generate container returned 400", slog.String("err", string(val))) - return nil, errors.New("llm-generate container returned 400") + if req.Stream != nil && *req.Stream { + return w.handleStreamingResponse(ctx, resp) } - if resp.JSON401 != nil { - val, err 
:= json.Marshal(resp.JSON401) - if err != nil { - return nil, err - } - slog.Error("llm-generate container returned 401", slog.String("err", string(val))) - return nil, errors.New("llm-generate container returned 401") - } - - if resp.JSON500 != nil { - val, err := json.Marshal(resp.JSON500) - if err != nil { - return nil, err - } - slog.Error("llm-generate container returned 500", slog.String("err", string(val))) - return nil, errors.New("llm-generate container returned 500") - } - - return resp.JSON200, nil + return w.handleNonStreamingResponse(resp) } func (w *Worker) Warm(ctx context.Context, pipeline string, modelID string, endpoint RunnerEndpoint, optimizationFlags OptimizationFlags) error { @@ -435,3 +426,94 @@ func (w *Worker) returnContainer(rc *RunnerContainer) { // Noop because we allow concurrent in-flight requests for external containers } } + +func (w *Worker) handleNonStreamingResponse(resp *LlmGenerateResponse) (*LlmResponse, error) { + if resp.JSON400 != nil { + val, err := json.Marshal(resp.JSON400) + if err != nil { + return nil, err + } + slog.Error("llm-generate container returned 400", slog.String("err", string(val))) + return nil, errors.New("llm-generate container returned 400") + } + + if resp.JSON401 != nil { + val, err := json.Marshal(resp.JSON401) + if err != nil { + return nil, err + } + slog.Error("llm-generate container returned 401", slog.String("err", string(val))) + return nil, errors.New("llm-generate container returned 401") + } + + if resp.JSON500 != nil { + val, err := json.Marshal(resp.JSON500) + if err != nil { + return nil, err + } + slog.Error("llm-generate container returned 500", slog.String("err", string(val))) + return nil, errors.New("llm-generate container returned 500") + } + + return resp.JSON200, nil +} + +type LlmStreamChunk struct { + Chunk string `json:"chunk,omitempty"` + TokensUsed int `json:"tokens_used,omitempty"` + Done bool `json:"done,omitempty"` +} + +func (w *Worker) handleStreamingResponse(ctx context.Context, resp *LlmGenerateResponse) (<-chan LlmStreamChunk, error) { + if resp.StatusCode() != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode()) + } + + outputChan := make(chan LlmStreamChunk, 10) + + go func() { + defer close(outputChan) + + reader := bufio.NewReader(bytes.NewReader(resp.Body)) + totalTokens := 0 + + for { + select { + case <-ctx.Done(): + return + default: + line, err := reader.ReadBytes('\n') + if err != nil { + if err != io.EOF { + slog.Error("Error reading stream", slog.String("err", err.Error())) + } + return + } + + if bytes.HasPrefix(line, []byte("data: ")) { + data := bytes.TrimPrefix(line, []byte("data: ")) + if string(data) == "[DONE]" { + outputChan <- LlmStreamChunk{Chunk: "[DONE]", Done: true, TokensUsed: totalTokens} + return + } + + var streamData LlmStreamChunk + if err := json.Unmarshal(data, &streamData); err != nil { + slog.Error("Error unmarshaling stream data", slog.String("err", err.Error())) + continue + } + + totalTokens += streamData.TokensUsed + + select { + case outputChan <- streamData: + case <-ctx.Done(): + return + } + } + } + } + }() + + return outputChan, nil +} From a11391f4796f46a2764cb5b3670b57357e0ac900 Mon Sep 17 00:00:00 2001 From: kyriediculous Date: Sat, 3 Aug 2024 16:52:19 +0200 Subject: [PATCH 06/10] Load LLM model distributed over multiple GPUs --- runner/app/pipelines/llm_generate.py | 80 +++++++++++++++++----------- runner/dl_checkpoints.sh | 2 +- 2 files changed, 50 insertions(+), 32 deletions(-) diff --git 
a/runner/app/pipelines/llm_generate.py b/runner/app/pipelines/llm_generate.py index 93025e0a..ac9b6018 100644 --- a/runner/app/pipelines/llm_generate.py +++ b/runner/app/pipelines/llm_generate.py @@ -1,35 +1,39 @@ import asyncio import logging import os -from typing import Dict, Any, List, Optional +from typing import Dict, Any, List, Optional, AsyncGenerator, Union import torch from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer +from accelerate import init_empty_weights, load_checkpoint_and_dispatch from app.pipelines.base import Pipeline from app.pipelines.utils import get_model_dir, get_torch_device -from huggingface_hub import file_download, hf_hub_download +from huggingface_hub import file_download, snapshot_download from threading import Thread -from typing import AsyncGenerator, Union, Dict, Any, Optional, List logger = logging.getLogger(__name__) - class LLMGeneratePipeline(Pipeline): def __init__(self, model_id: str): self.model_id = model_id kwargs = { - "cache_dir": get_model_dir() + "cache_dir": get_model_dir(), + "local_files_only": True } self.device = get_torch_device() - folder_name = file_download.repo_folder_name( - repo_id=model_id, repo_type="model" - ) - folder_path = os.path.join(get_model_dir(), folder_name) + + # Generate the correct folder name + folder_path = file_download.repo_folder_name(repo_id=model_id, repo_type="model") + self.local_model_path = os.path.join(get_model_dir(), folder_path) + self.checkpoint_dir = snapshot_download(model_id, cache_dir=get_model_dir(), local_files_only=True) + + logger.info(f"Local model path: {self.local_model_path}") + logger.info(f"Directory contents: {os.listdir(self.local_model_path)}") # Check for fp16 variant has_fp16_variant = any( ".fp16.safetensors" in fname - for _, _, files in os.walk(folder_path) + for _, _, files in os.walk(self.local_model_path) for fname in files ) if self.device != "cpu" and has_fp16_variant: @@ -39,12 +43,36 @@ def __init__(self, model_id: str): elif self.device != "cpu": kwargs["torch_dtype"] = torch.bfloat16 - # Add device mapping - kwargs["device_map"] = "auto" - logger.info(f"Loading model {model_id}") self.tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs) - self.model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs) + + # Load the model configuration + config = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).config + + # Initialize empty weights + with init_empty_weights(): + self.model = AutoModelForCausalLM.from_config(config) + + # Prepare for distributed setup + num_gpus = torch.cuda.device_count() + max_memory = {i: f"{torch.cuda.get_device_properties(i).total_memory // 1024**3}GiB" for i in range(num_gpus)} + max_memory["cpu"] = "24GiB" # Adjust based on your system's RAM + + logger.info(f"Max memory configuration: {max_memory}") + + # Load and dispatch the model + self.model = load_checkpoint_and_dispatch( + self.model, + self.checkpoint_dir, + device_map="auto", + max_memory=max_memory, + no_split_module_classes=["LlamaDecoderLayer"], # Adjust based on your model architecture + dtype=kwargs.get("torch_dtype", torch.float32), + offload_folder="offload", # Optional: specify a folder for offloading + offload_state_dict=True, # Optional: offload state dict to CPU + ) + + logger.info(f"Model loaded and distributed. 
Device map: {self.model.hf_device_map}") # Set up generation config self.generation_config = self.model.generation_config @@ -70,39 +98,29 @@ async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, sys conversation.append({"role": "system", "content": system_msg}) if history: for user, assistant in history: - conversation.extend([{"role": "user", "content": user}, { - "role": "assistant", "content": assistant}]) + conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}]) conversation.append({"role": "user", "content": prompt}) - input_ids = self.tokenizer.apply_chat_template( - conversation, return_tensors="pt").to(self.model.device) + input_ids = self.tokenizer.apply_chat_template(conversation, return_tensors="pt").to(self.model.device) attention_mask = torch.ones_like(input_ids) max_new_tokens = kwargs.get("max_tokens", 256) temperature = kwargs.get("temperature", 0.7) - streamer = TextIteratorStreamer( - self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) + streamer = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) - # Start with the generation config generate_kwargs = self.generation_config.to_dict() - # Update with our specific parameters generate_kwargs.update({ "input_ids": input_ids, "attention_mask": attention_mask, "streamer": streamer, "max_new_tokens": max_new_tokens, - "do_sample": True, + "do_sample": temperature > 0, "temperature": temperature, "eos_token_id": self.tokenizer.eos_token_id, "pad_token_id": self.tokenizer.eos_token_id, }) - # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash. - if temperature == 0: - generate_kwargs['do_sample'] = False - - # Start generation in a separate thread thread = Thread(target=self.model_generate_wrapper, kwargs=generate_kwargs) thread.start() @@ -119,9 +137,6 @@ async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, sys input_length = input_ids.size(1) yield {"tokens_used": input_length + total_tokens} - def __str__(self): - return f"LLMGeneratePipeline(model_id={self.model_id})" - def model_generate_wrapper(self, **kwargs): try: logger.debug("Entering model.generate") @@ -131,3 +146,6 @@ def model_generate_wrapper(self, **kwargs): except Exception as e: logger.error(f"Error in model.generate: {str(e)}", exc_info=True) raise + + def __str__(self): + return f"LLMGeneratePipeline(model_id={self.model_id})" \ No newline at end of file diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh index 8af42a54..1ac90cd8 100755 --- a/runner/dl_checkpoints.sh +++ b/runner/dl_checkpoints.sh @@ -62,7 +62,7 @@ function download_all_models() { huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models # Download LLM models (Warning: large model size) - huggingface-cli download meta-llama/Meta-Llama-3-8B-Instruct --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models + huggingface-cli download meta-llama/Meta-Llama-3.1-8B-Instruct --include "*.json" "*.bin" "*.safetensors" "*.txt" --cache-dir models } From 87bfe3f909eb2b5809db774759049e514ef240de Mon Sep 17 00:00:00 2001 From: kyriediculous Date: Mon, 5 Aug 2024 17:28:49 +0200 Subject: [PATCH 07/10] feat: support 8bit and fp16 for llm pipeline --- runner/app/pipelines/llm_generate.py | 124 +++++++++++++++++---------- runner/requirements.txt | 2 + 2 files changed, 82 
insertions(+), 44 deletions(-) diff --git a/runner/app/pipelines/llm_generate.py b/runner/app/pipelines/llm_generate.py index ac9b6018..3606fe0f 100644 --- a/runner/app/pipelines/llm_generate.py +++ b/runner/app/pipelines/llm_generate.py @@ -1,10 +1,11 @@ import asyncio import logging import os +import psutil from typing import Dict, Any, List, Optional, AsyncGenerator, Union import torch -from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer +from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig from accelerate import init_empty_weights, load_checkpoint_and_dispatch from app.pipelines.base import Pipeline from app.pipelines.utils import get_model_dir, get_torch_device @@ -13,15 +14,84 @@ logger = logging.getLogger(__name__) +def get_max_memory(): + num_gpus = torch.cuda.device_count() + gpu_memory = {i: f"{torch.cuda.get_device_properties(i).total_memory // 1024**3}GiB" for i in range(num_gpus)} + cpu_memory = f"{psutil.virtual_memory().available // 1024**3}GiB" + max_memory = {**gpu_memory, "cpu": cpu_memory} + + logger.info(f"Max memory configuration: {max_memory}") + return max_memory + +def load_model_8bit(model_id: str, **kwargs): + max_memory = get_max_memory() + + quantization_config = BitsAndBytesConfig( + load_in_8bit=True, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + ) + + tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs) + + model = AutoModelForCausalLM.from_pretrained( + model_id, + quantization_config=quantization_config, + device_map="auto", + max_memory=max_memory, + offload_folder="offload", + low_cpu_mem_usage=True, + **kwargs + ) + + return tokenizer, model + +def load_model_fp16(model_id: str, **kwargs): + device = get_torch_device() + max_memory = get_max_memory() + + # Check for fp16 variant + local_model_path = os.path.join(get_model_dir(), file_download.repo_folder_name(repo_id=model_id, repo_type="model")) + has_fp16_variant = any(".fp16.safetensors" in fname for _, _, files in os.walk(local_model_path) for fname in files) + + if device != "cpu" and has_fp16_variant: + logger.info("Loading fp16 variant for %s", model_id) + kwargs["torch_dtype"] = torch.float16 + kwargs["variant"] = "fp16" + elif device != "cpu": + kwargs["torch_dtype"] = torch.bfloat16 + + tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs) + + config = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).config + + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + + checkpoint_dir = snapshot_download(model_id, cache_dir=get_model_dir(), local_files_only=True) + + model = load_checkpoint_and_dispatch( + model, + checkpoint_dir, + device_map="auto", + max_memory=max_memory, + no_split_module_classes=["LlamaDecoderLayer"], # Adjust based on your model architecture + dtype=kwargs.get("torch_dtype", torch.float32), + offload_folder="offload", + offload_state_dict=True, + ) + + return tokenizer, model + class LLMGeneratePipeline(Pipeline): def __init__(self, model_id: str): self.model_id = model_id kwargs = { "cache_dir": get_model_dir(), - "local_files_only": True + "local_files_only": True, } self.device = get_torch_device() - + # Generate the correct folder name folder_path = file_download.repo_folder_name(repo_id=model_id, repo_type="model") self.local_model_path = os.path.join(get_model_dir(), folder_path) @@ -30,47 +100,14 @@ def __init__(self, model_id: str): logger.info(f"Local model path: {self.local_model_path}") logger.info(f"Directory 
contents: {os.listdir(self.local_model_path)}") - # Check for fp16 variant - has_fp16_variant = any( - ".fp16.safetensors" in fname - for _, _, files in os.walk(self.local_model_path) - for fname in files - ) - if self.device != "cpu" and has_fp16_variant: - logger.info("LLMGeneratePipeline loading fp16 variant for %s", model_id) - kwargs["torch_dtype"] = torch.float16 - kwargs["variant"] = "fp16" - elif self.device != "cpu": - kwargs["torch_dtype"] = torch.bfloat16 - - logger.info(f"Loading model {model_id}") - self.tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs) - - # Load the model configuration - config = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).config + use_8bit = os.getenv("USE_8BIT", "").strip().lower() == "true" - # Initialize empty weights - with init_empty_weights(): - self.model = AutoModelForCausalLM.from_config(config) - - # Prepare for distributed setup - num_gpus = torch.cuda.device_count() - max_memory = {i: f"{torch.cuda.get_device_properties(i).total_memory // 1024**3}GiB" for i in range(num_gpus)} - max_memory["cpu"] = "24GiB" # Adjust based on your system's RAM - - logger.info(f"Max memory configuration: {max_memory}") - - # Load and dispatch the model - self.model = load_checkpoint_and_dispatch( - self.model, - self.checkpoint_dir, - device_map="auto", - max_memory=max_memory, - no_split_module_classes=["LlamaDecoderLayer"], # Adjust based on your model architecture - dtype=kwargs.get("torch_dtype", torch.float32), - offload_folder="offload", # Optional: specify a folder for offloading - offload_state_dict=True, # Optional: offload state dict to CPU - ) + if use_8bit: + logger.info("Using 8-bit quantization") + self.tokenizer, self.model = load_model_8bit(model_id, **kwargs) + else: + logger.info("Using fp16/bf16 precision") + self.tokenizer, self.model = load_model_fp16(model_id, **kwargs) logger.info(f"Model loaded and distributed. 
Device map: {self.model.hf_device_map}") @@ -91,7 +128,6 @@ def __init__(self, model_id: str): ) from app.pipelines.optim.sfast import compile_model self.model = compile_model(self.model) - async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, system_msg: Optional[str] = None, **kwargs) -> AsyncGenerator[Union[str, Dict[str, Any]], None]: conversation = [] if system_msg: diff --git a/runner/requirements.txt b/runner/requirements.txt index c0ae34b1..7ade3436 100644 --- a/runner/requirements.txt +++ b/runner/requirements.txt @@ -17,3 +17,5 @@ numpy==1.26.4 av==12.1.0 sentencepiece== 0.2.0 protobuf==5.27.2 +bitsandbytes==0.43.3 +psutil==6.0.0 From 468d65dca834eac687bf29d9db96f3fc5cd1464a Mon Sep 17 00:00:00 2001 From: kyriediculous Date: Tue, 6 Aug 2024 04:15:36 +0200 Subject: [PATCH 08/10] fix streaming and full multipart body for llm --- worker/multipart.go | 12 ++++++++++++ worker/worker.go | 47 ++++++++++++++++++++++----------------------- 2 files changed, 35 insertions(+), 24 deletions(-) diff --git a/worker/multipart.go b/worker/multipart.go index 7a87bcae..09d2d8a4 100644 --- a/worker/multipart.go +++ b/worker/multipart.go @@ -248,6 +248,12 @@ func NewLlmGenerateMultipartWriter(w io.Writer, req BodyLlmGenerateLlmGeneratePo return nil, fmt.Errorf("failed to write prompt field: %w", err) } + if req.History != nil { + if err := mw.WriteField("history", *req.History); err != nil { + return nil, fmt.Errorf("failed to write history field: %w", err) + } + } + if req.ModelId != nil { if err := mw.WriteField("model_id", *req.ModelId); err != nil { return nil, fmt.Errorf("failed to write model_id field: %w", err) @@ -272,6 +278,12 @@ func NewLlmGenerateMultipartWriter(w io.Writer, req BodyLlmGenerateLlmGeneratePo } } + if req.Stream != nil { + if err := mw.WriteField("stream", fmt.Sprintf("%v", *req.Stream)); err != nil { + return nil, fmt.Errorf("failed to write stream field: %w", err) + } + } + if err := mw.Close(); err != nil { return nil, fmt.Errorf("failed to close multipart writer: %w", err) } diff --git a/worker/worker.go b/worker/worker.go index 41f91610..cda0faf8 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -7,10 +7,10 @@ import ( "encoding/json" "errors" "fmt" - "io" "log/slog" "net/http" "strconv" + "strings" "sync" ) @@ -309,7 +309,6 @@ func (w *Worker) AudioToText(ctx context.Context, req AudioToTextMultipartReques } func (w *Worker) LlmGenerate(ctx context.Context, req LlmGenerateFormdataRequestBody) (interface{}, error) { - slog.Info("Incoming request %v", req) c, err := w.borrowContainer(ctx, "llm-generate", *req.ModelId) if err != nil { return nil, err @@ -331,15 +330,18 @@ func (w *Worker) LlmGenerate(ctx context.Context, req LlmGenerateFormdataRequest return nil, err } - resp, err := c.Client.LlmGenerateWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf) - if err != nil { - return nil, err - } - if req.Stream != nil && *req.Stream { + resp, err := c.Client.LlmGenerateWithBody(ctx, mw.FormDataContentType(), &buf) + if err != nil { + return nil, err + } return w.handleStreamingResponse(ctx, resp) } + resp, err := c.Client.LlmGenerateWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf) + if err != nil { + return nil, err + } return w.handleNonStreamingResponse(resp) } @@ -464,9 +466,9 @@ type LlmStreamChunk struct { Done bool `json:"done,omitempty"` } -func (w *Worker) handleStreamingResponse(ctx context.Context, resp *LlmGenerateResponse) (<-chan LlmStreamChunk, error) { - if resp.StatusCode() != http.StatusOK { - return nil, 
fmt.Errorf("unexpected status code: %d", resp.StatusCode()) +func (w *Worker) handleStreamingResponse(ctx context.Context, resp *http.Response) (<-chan LlmStreamChunk, error) { + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) } outputChan := make(chan LlmStreamChunk, 10) @@ -474,31 +476,24 @@ func (w *Worker) handleStreamingResponse(ctx context.Context, resp *LlmGenerateR go func() { defer close(outputChan) - reader := bufio.NewReader(bytes.NewReader(resp.Body)) + scanner := bufio.NewScanner(resp.Body) totalTokens := 0 - for { + for scanner.Scan() { select { case <-ctx.Done(): return default: - line, err := reader.ReadBytes('\n') - if err != nil { - if err != io.EOF { - slog.Error("Error reading stream", slog.String("err", err.Error())) - } - return - } - - if bytes.HasPrefix(line, []byte("data: ")) { - data := bytes.TrimPrefix(line, []byte("data: ")) - if string(data) == "[DONE]" { + line := scanner.Text() + if strings.HasPrefix(line, "data: ") { + data := strings.TrimPrefix(line, "data: ") + if data == "[DONE]" { outputChan <- LlmStreamChunk{Chunk: "[DONE]", Done: true, TokensUsed: totalTokens} return } var streamData LlmStreamChunk - if err := json.Unmarshal(data, &streamData); err != nil { + if err := json.Unmarshal([]byte(data), &streamData); err != nil { slog.Error("Error unmarshaling stream data", slog.String("err", err.Error())) continue } @@ -513,6 +508,10 @@ func (w *Worker) handleStreamingResponse(ctx context.Context, resp *LlmGenerateR } } } + + if err := scanner.Err(); err != nil { + slog.Error("Error reading stream", slog.String("err", err.Error())) + } }() return outputChan, nil From 401dbbec05aa87f86764aa15a89c30182c09a532 Mon Sep 17 00:00:00 2001 From: kyriediculous Date: Wed, 7 Aug 2024 01:28:37 +0200 Subject: [PATCH 09/10] fix history parsing --- runner/app/pipelines/llm_generate.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/runner/app/pipelines/llm_generate.py b/runner/app/pipelines/llm_generate.py index 3606fe0f..a893c650 100644 --- a/runner/app/pipelines/llm_generate.py +++ b/runner/app/pipelines/llm_generate.py @@ -133,8 +133,7 @@ async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, sys if system_msg: conversation.append({"role": "system", "content": system_msg}) if history: - for user, assistant in history: - conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}]) + conversation.extend(history) conversation.append({"role": "user", "content": prompt}) input_ids = self.tokenizer.apply_chat_template(conversation, return_tensors="pt").to(self.model.device) From 35fd2b877bfc14655a16530437221999fdedec2b Mon Sep 17 00:00:00 2001 From: kyriediculous Date: Sun, 25 Aug 2024 21:35:29 +0200 Subject: [PATCH 10/10] llm: LoRa support --- runner/app/pipelines/llm_generate.py | 79 +++++++++++++++++++++++----- runner/app/routes/llm_generate.py | 24 +++++++-- runner/requirements.txt | 2 +- 3 files changed, 88 insertions(+), 17 deletions(-) diff --git a/runner/app/pipelines/llm_generate.py b/runner/app/pipelines/llm_generate.py index a893c650..6a353d17 100644 --- a/runner/app/pipelines/llm_generate.py +++ b/runner/app/pipelines/llm_generate.py @@ -1,16 +1,21 @@ import asyncio +import base64 +import io import logging import os import psutil from typing import Dict, Any, List, Optional, AsyncGenerator, Union +from queue import Queue +from threading import Thread import torch from transformers import AutoTokenizer, 
AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig from accelerate import init_empty_weights, load_checkpoint_and_dispatch +from peft import PeftModel, PeftConfig +from safetensors.torch import load_file as safe_load_file from app.pipelines.base import Pipeline from app.pipelines.utils import get_model_dir, get_torch_device from huggingface_hub import file_download, snapshot_download -from threading import Thread logger = logging.getLogger(__name__) @@ -50,7 +55,6 @@ def load_model_fp16(model_id: str, **kwargs): device = get_torch_device() max_memory = get_max_memory() - # Check for fp16 variant local_model_path = os.path.join(get_model_dir(), file_download.repo_folder_name(repo_id=model_id, repo_type="model")) has_fp16_variant = any(".fp16.safetensors" in fname for _, _, files in os.walk(local_model_path) for fname in files) @@ -92,7 +96,6 @@ def __init__(self, model_id: str): } self.device = get_torch_device() - # Generate the correct folder name folder_path = file_download.repo_folder_name(repo_id=model_id, repo_type="model") self.local_model_path = os.path.join(get_model_dir(), folder_path) self.checkpoint_dir = snapshot_download(model_id, cache_dir=get_model_dir(), local_files_only=True) @@ -111,7 +114,6 @@ def __init__(self, model_id: str): logger.info(f"Model loaded and distributed. Device map: {self.model.hf_device_map}") - # Set up generation config self.generation_config = self.model.generation_config self.terminators = [ @@ -119,7 +121,6 @@ def __init__(self, model_id: str): self.tokenizer.convert_tokens_to_ids("<|eot_id|>") ] - # Optional: Add optimizations sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true" if sfast_enabled: logger.info( @@ -128,15 +129,55 @@ def __init__(self, model_id: str): ) from app.pipelines.optim.sfast import compile_model self.model = compile_model(self.model) - async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, system_msg: Optional[str] = None, **kwargs) -> AsyncGenerator[Union[str, Dict[str, Any]], None]: + + self.request_queue = Queue() + self.worker_thread = Thread(target=self._process_queue, daemon=True) + self.worker_thread.start() + + def _process_queue(self): + while True: + task = self.request_queue.get() + if task is None: + break + self._process_task(*task) + + def _process_task(self, prompt, history, system_msg, lora_weights, future, **kwargs): + try: + if lora_weights: + model = self.apply_lora(self.model, lora_weights) + else: + model = self.model + + result = self._generate(model, prompt, history, system_msg, **kwargs) + future.set_result(result) + except Exception as e: + future.set_exception(e) + finally: + if lora_weights: + self.model.delete_adapters() + torch.cuda.empty_cache() + + def apply_lora(self, model, lora_weights: str): + try: + lora_bytes = base64.b64decode(lora_weights) + with io.BytesIO(lora_bytes) as f: + lora_state_dict = safe_load_file(f) + model = PeftModel.from_pretrained(model, lora_state_dict, is_trainable=False) + logger.info("Applied LoRA weights") + return model + except Exception as e: + logger.error(f"Error applying LoRA weights: {str(e)}") + raise + + def _generate(self, model, prompt, history, system_msg, **kwargs): conversation = [] if system_msg: conversation.append({"role": "system", "content": system_msg}) if history: - conversation.extend(history) + conversation.extend(history) conversation.append({"role": "user", "content": prompt}) - input_ids = self.tokenizer.apply_chat_template(conversation, return_tensors="pt").to(self.model.device) + input_ids = 
self.tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device) attention_mask = torch.ones_like(input_ids) max_new_tokens = kwargs.get("max_tokens", 256) @@ -156,7 +197,7 @@ async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, sys "pad_token_id": self.tokenizer.eos_token_id, }) - thread = Thread(target=self.model_generate_wrapper, kwargs=generate_kwargs) + thread = Thread(target=self.model_generate_wrapper, args=(model,), kwargs=generate_kwargs) thread.start() total_tokens = 0 @@ -164,23 +205,37 @@ async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, sys for text in streamer: total_tokens += 1 yield text - await asyncio.sleep(0) # Allow other tasks to run except Exception as e: logger.error(f"Error during streaming: {str(e)}") raise + finally: + thread.join() input_length = input_ids.size(1) yield {"tokens_used": input_length + total_tokens} - def model_generate_wrapper(self, **kwargs): + def model_generate_wrapper(self, model, **kwargs): try: logger.debug("Entering model.generate") with torch.cuda.amp.autocast(): # Use automatic mixed precision - self.model.generate(**kwargs) + model.generate(**kwargs) logger.debug("Exiting model.generate") except Exception as e: logger.error(f"Error in model.generate: {str(e)}", exc_info=True) raise + async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, system_msg: Optional[str] = None, lora_weights: Optional[str] = None, **kwargs) -> AsyncGenerator[Union[str, Dict[str, Any]], None]: + loop = asyncio.get_running_loop() + future = loop.create_future() + self.request_queue.put((prompt, history, system_msg, lora_weights, future, kwargs)) + result = await future + + async for item in result: + yield item + + def cleanup(self): + self.request_queue.put(None) + self.worker_thread.join() + def __str__(self): return f"LLMGeneratePipeline(model_id={self.model_id})" \ No newline at end of file diff --git a/runner/app/routes/llm_generate.py b/runner/app/routes/llm_generate.py index 07e70ff8..a39dd711 100644 --- a/runner/app/routes/llm_generate.py +++ b/runner/app/routes/llm_generate.py @@ -1,12 +1,12 @@ import logging import os -from typing import Annotated, Optional, List -from fastapi import APIRouter, Depends, Form, status, Request +from typing import Annotated +from fastapi import APIRouter, Depends, Form, status from fastapi.responses import JSONResponse, StreamingResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from app.dependencies import get_pipeline from app.pipelines.base import Pipeline -from app.routes.util import HTTPError, LlmResponse, TextResponse, http_error +from app.routes.util import HTTPError, LlmResponse, http_error import json router = APIRouter() @@ -32,6 +32,7 @@ async def llm_generate( max_tokens: Annotated[int, Form()] = 256, history: Annotated[str, Form()] = "[]", # We'll parse this as JSON stream: Annotated[bool, Form()] = False, + lora_weights: Annotated[str, Form()] = None, # New parameter for LoRA weights pipeline: Pipeline = Depends(get_pipeline), token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), ): @@ -58,18 +59,33 @@ async def llm_generate( if not isinstance(history_list, list): raise ValueError("History must be a JSON array") + # Validate LoRA weights if provided + if lora_weights: + try: + # Attempt to decode base64 string + import base64 + base64.b64decode(lora_weights) + except: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=http_error( + "Invalid 
LoRA weights format. Must be a valid base64 string.") + ) + generator = pipeline( prompt=prompt, history=history_list, system_msg=system_msg if system_msg else None, temperature=temperature, - max_tokens=max_tokens + max_tokens=max_tokens, + lora_weights=lora_weights # Pass LoRA weights to the pipeline ) if stream: return StreamingResponse(stream_generator(generator), media_type="text/event-stream") else: full_response = "" + tokens_used = 0 async for chunk in generator: if isinstance(chunk, dict): tokens_used = chunk["tokens_used"] diff --git a/runner/requirements.txt b/runner/requirements.txt index 7ade3436..ef1102b2 100644 --- a/runner/requirements.txt +++ b/runner/requirements.txt @@ -15,7 +15,7 @@ safetensors==0.4.3 scipy==1.13.0 numpy==1.26.4 av==12.1.0 -sentencepiece== 0.2.0 +sentencepiece==0.2.0 protobuf==5.27.2 bitsandbytes==0.43.3 psutil==6.0.0
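
A minimal Python client sketch for the streaming path added in this series, intended as a reference for trying the endpoint end to end. The host and port (http://localhost:8000), the optional bearer token, and the example prompt/model values are assumptions, not defined by these patches; the form field names, the "data: ..." SSE framing with the "[DONE]" terminator, and the final {"tokens_used": ...} event follow the llm_generate route and stream_generator above.

import json
import requests

URL = "http://localhost:8000/llm-generate"  # assumed host/port, not set by these patches
# headers = {"Authorization": f"Bearer {token}"}  # only needed when AUTH_TOKEN is set on the runner

form = {
    "prompt": "What is the capital of France?",
    "model_id": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "system_msg": "You are a helpful assistant.",
    # history is a JSON-encoded array of chat messages (see PATCH 09)
    "history": json.dumps([
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello! How can I help?"},
    ]),
    "max_tokens": 128,
    "temperature": 0.7,
    "stream": "true",
}

with requests.post(URL, data=form, stream=True, timeout=600) as resp:
    resp.raise_for_status()
    for raw in resp.iter_lines(decode_unicode=True):
        if not raw or not raw.startswith("data: "):
            continue
        data = raw[len("data: "):]
        if data == "[DONE]":              # terminator emitted by stream_generator
            break
        event = json.loads(data)
        if "chunk" in event:
            print(event["chunk"], end="", flush=True)
        elif "tokens_used" in event:      # final bookkeeping event
            print(f"\n[tokens_used={event['tokens_used']}]")
        elif "error" in event:            # emitted when streaming fails server-side
            raise RuntimeError(event["error"])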
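
When stream is omitted or false, the same form yields a single LlmResponse JSON object ({"response": ..., "tokens_used": ...}) rather than an event stream; a sketch reusing the form dict from the previous example:

form_sync = dict(form, stream="false")
resp = requests.post(URL, data=form_sync, timeout=600)
resp.raise_for_status()
body = resp.json()
print(body["response"])
print("tokens used:", body["tokens_used"])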
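
PATCH 10 adds a lora_weights form field; the route only checks that the value decodes as base64, and the pipeline decodes it before applying the adapter. A sketch of how a caller might populate the field, again reusing the form dict above. The adapter filename is hypothetical, and the exact safetensors layout the pipeline ends up accepting is not pinned down by this series, so treat this as client-side encoding only:

import base64

with open("adapter_model.safetensors", "rb") as f:  # hypothetical LoRA adapter file
    form["lora_weights"] = base64.b64encode(f.read()).decode("ascii")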