llm: LoRA support #4

Open · wants to merge 10 commits into base: llm
7 changes: 7 additions & 0 deletions runner/app/main.py
@@ -52,6 +52,9 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
            from app.pipelines.upscale import UpscalePipeline

            return UpscalePipeline(model_id)
        case "llm-generate":
            from app.pipelines.llm_generate import LLMGeneratePipeline
            return LLMGeneratePipeline(model_id)
        case _:
            raise EnvironmentError(
                f"{pipeline} is not a valid pipeline for model {model_id}"
@@ -82,6 +85,10 @@ def load_route(pipeline: str) -> any:
            from app.routes import upscale

            return upscale.router
        case "llm-generate":
            from app.routes import llm_generate

            return llm_generate.router
        case _:
            raise EnvironmentError(f"{pipeline} is not a valid pipeline")

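For orientation: with the two cases above, the existing dispatchers resolve the new pipeline and route by name. A minimal sketch of how they would be invoked, where the model ID is only an illustrative placeholder and not something this diff pins down:

# Illustrative only, not part of the diff; the model ID is a placeholder.
pipeline = load_pipeline("llm-generate", "meta-llama/Meta-Llama-3-8B-Instruct")
router = load_route("llm-generate")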
241 changes: 241 additions & 0 deletions runner/app/pipelines/llm_generate.py
@@ -0,0 +1,241 @@
import asyncio
import base64
import io
import logging
import os
import psutil
from typing import Dict, Any, List, Optional, AsyncGenerator, Union
from queue import Queue
from threading import Thread

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from peft import PeftModel, PeftConfig
from safetensors.torch import load as safe_load
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_torch_device
from huggingface_hub import file_download, snapshot_download

logger = logging.getLogger(__name__)

def get_max_memory():
    num_gpus = torch.cuda.device_count()
    gpu_memory = {i: f"{torch.cuda.get_device_properties(i).total_memory // 1024**3}GiB" for i in range(num_gpus)}
    cpu_memory = f"{psutil.virtual_memory().available // 1024**3}GiB"
    max_memory = {**gpu_memory, "cpu": cpu_memory}

    logger.info(f"Max memory configuration: {max_memory}")
    return max_memory

def load_model_8bit(model_id: str, **kwargs):
    max_memory = get_max_memory()

    quantization_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quantization_config,
        device_map="auto",
        max_memory=max_memory,
        offload_folder="offload",
        low_cpu_mem_usage=True,
        **kwargs
    )

    return tokenizer, model

def load_model_fp16(model_id: str, **kwargs):
    device = get_torch_device()
    max_memory = get_max_memory()

    local_model_path = os.path.join(get_model_dir(), file_download.repo_folder_name(repo_id=model_id, repo_type="model"))
    has_fp16_variant = any(".fp16.safetensors" in fname for _, _, files in os.walk(local_model_path) for fname in files)

    if device != "cpu" and has_fp16_variant:
        logger.info("Loading fp16 variant for %s", model_id)
        kwargs["torch_dtype"] = torch.float16
        kwargs["variant"] = "fp16"
    elif device != "cpu":
        kwargs["torch_dtype"] = torch.bfloat16

    tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)

    config = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).config

    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)

    checkpoint_dir = snapshot_download(model_id, cache_dir=get_model_dir(), local_files_only=True)

    model = load_checkpoint_and_dispatch(
        model,
        checkpoint_dir,
        device_map="auto",
        max_memory=max_memory,
        no_split_module_classes=["LlamaDecoderLayer"],  # Adjust based on your model architecture
        dtype=kwargs.get("torch_dtype", torch.float32),
        offload_folder="offload",
        offload_state_dict=True,
    )

    return tokenizer, model

class LLMGeneratePipeline(Pipeline):
    def __init__(self, model_id: str):
        self.model_id = model_id
        kwargs = {
            "cache_dir": get_model_dir(),
            "local_files_only": True,
        }
        self.device = get_torch_device()

        folder_path = file_download.repo_folder_name(repo_id=model_id, repo_type="model")
        self.local_model_path = os.path.join(get_model_dir(), folder_path)
        self.checkpoint_dir = snapshot_download(model_id, cache_dir=get_model_dir(), local_files_only=True)

        logger.info(f"Local model path: {self.local_model_path}")
        logger.info(f"Directory contents: {os.listdir(self.local_model_path)}")

        use_8bit = os.getenv("USE_8BIT", "").strip().lower() == "true"

        if use_8bit:
            logger.info("Using 8-bit quantization")
            self.tokenizer, self.model = load_model_8bit(model_id, **kwargs)
        else:
            logger.info("Using fp16/bf16 precision")
            self.tokenizer, self.model = load_model_fp16(model_id, **kwargs)

        logger.info(f"Model loaded and distributed. Device map: {self.model.hf_device_map}")

        self.generation_config = self.model.generation_config

        self.terminators = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]

        sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
        if sfast_enabled:
            logger.info(
                "LLMGeneratePipeline will be dynamically compiled with stable-fast for %s",
                model_id,
            )
            from app.pipelines.optim.sfast import compile_model
            self.model = compile_model(self.model)

        self.request_queue = Queue()
        self.worker_thread = Thread(target=self._process_queue, daemon=True)
        self.worker_thread.start()

    def _process_queue(self):
        while True:
            task = self.request_queue.get()
            if task is None:
                break
            # Each queued task is (prompt, history, system_msg, lora_weights, future, kwargs);
            # unpack the kwargs dict so it is forwarded as keyword arguments.
            prompt, history, system_msg, lora_weights, future, kwargs = task
            self._process_task(prompt, history, system_msg, lora_weights, future, **kwargs)

    def _process_task(self, prompt, history, system_msg, lora_weights, future, **kwargs):
        try:
            if lora_weights:
                model = self.apply_lora(self.model, lora_weights)
            else:
                model = self.model

            result = self._generate(model, prompt, history, system_msg, **kwargs)
            # The future belongs to the event loop; resolve it thread-safely from this worker thread.
            future.get_loop().call_soon_threadsafe(future.set_result, result)
        except Exception as e:
            future.get_loop().call_soon_threadsafe(future.set_exception, e)
        finally:
            if lora_weights:
                self.model.delete_adapters()
                torch.cuda.empty_cache()

    def apply_lora(self, model, lora_weights: str):
        try:
            lora_bytes = base64.b64decode(lora_weights)
            # safetensors can deserialize a state dict directly from the raw bytes.
            lora_state_dict = safe_load(lora_bytes)
            model = PeftModel.from_pretrained(model, lora_state_dict, is_trainable=False)
            logger.info("Applied LoRA weights")
            return model
        except Exception as e:
            logger.error(f"Error applying LoRA weights: {str(e)}")
            raise

    def _generate(self, model, prompt, history, system_msg, **kwargs):
        conversation = []
        if system_msg:
            conversation.append({"role": "system", "content": system_msg})
        if history:
            conversation.extend(history)
        conversation.append({"role": "user", "content": prompt})

        input_ids = self.tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
        attention_mask = torch.ones_like(input_ids)

        max_new_tokens = kwargs.get("max_tokens", 256)
        temperature = kwargs.get("temperature", 0.7)

        streamer = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

        generate_kwargs = self.generation_config.to_dict()
        generate_kwargs.update({
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0,
            "temperature": temperature,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.eos_token_id,
        })

        thread = Thread(target=self.model_generate_wrapper, args=(model,), kwargs=generate_kwargs)
        thread.start()

        total_tokens = 0
        try:
            for text in streamer:
                total_tokens += 1
                yield text
        except Exception as e:
            logger.error(f"Error during streaming: {str(e)}")
            raise
        finally:
            thread.join()

        input_length = input_ids.size(1)
        yield {"tokens_used": input_length + total_tokens}

    def model_generate_wrapper(self, model, **kwargs):
        try:
            logger.debug("Entering model.generate")
            with torch.cuda.amp.autocast():  # Use automatic mixed precision
                model.generate(**kwargs)
            logger.debug("Exiting model.generate")
        except Exception as e:
            logger.error(f"Error in model.generate: {str(e)}", exc_info=True)
            raise

    async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, system_msg: Optional[str] = None, lora_weights: Optional[str] = None, **kwargs) -> AsyncGenerator[Union[str, Dict[str, Any]], None]:
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        self.request_queue.put((prompt, history, system_msg, lora_weights, future, kwargs))
        result = await future

        # The worker returns a synchronous generator; pull items through the default
        # executor so blocking token reads do not stall the event loop.
        done = object()
        while True:
            item = await loop.run_in_executor(None, next, result, done)
            if item is done:
                break
            yield item

    def cleanup(self):
        self.request_queue.put(None)
        self.worker_thread.join()

    def __str__(self):
        return f"LLMGeneratePipeline(model_id={self.model_id})"
126 changes: 126 additions & 0 deletions runner/app/routes/llm_generate.py
@@ -0,0 +1,126 @@
import base64
import json
import logging
import os
from typing import Annotated, Optional

from fastapi import APIRouter, Depends, Form, status
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import HTTPError, LlmResponse, http_error

router = APIRouter()

logger = logging.getLogger(__name__)

RESPONSES = {
    status.HTTP_200_OK: {"model": LlmResponse},
    status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
    status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
    status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
}


@router.post("/llm-generate",
response_model=LlmResponse, responses=RESPONSES)
@router.post("/llm-generate/", response_model=LlmResponse, responses=RESPONSES, include_in_schema=False)
async def llm_generate(
    prompt: Annotated[str, Form()],
    model_id: Annotated[str, Form()] = "",
    system_msg: Annotated[str, Form()] = "",
    temperature: Annotated[float, Form()] = 0.7,
    max_tokens: Annotated[int, Form()] = 256,
    history: Annotated[str, Form()] = "[]",  # parsed as a JSON array
    stream: Annotated[bool, Form()] = False,
    lora_weights: Annotated[Optional[str], Form()] = None,  # base64-encoded LoRA weights
    pipeline: Pipeline = Depends(get_pipeline),
    token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
auth_token = os.environ.get("AUTH_TOKEN")
if auth_token:
if not token or token.credentials != auth_token:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
headers={"WWW-Authenticate": "Bearer"},
content=http_error("Invalid bearer token"),
)

if model_id != "" and model_id != pipeline.model_id:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(
f"pipeline configured with {pipeline.model_id} but called with "
f"{model_id}"
),
)

    try:
        history_list = json.loads(history)
        if not isinstance(history_list, list):
            raise ValueError("History must be a JSON array")

        # Validate LoRA weights if provided
        if lora_weights:
            try:
                # Attempt to decode the base64 string
                base64.b64decode(lora_weights, validate=True)
            except Exception:
                return JSONResponse(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    content=http_error(
                        "Invalid LoRA weights format. Must be a valid base64 string."),
                )

        generator = pipeline(
            prompt=prompt,
            history=history_list,
            system_msg=system_msg if system_msg else None,
            temperature=temperature,
            max_tokens=max_tokens,
            lora_weights=lora_weights,  # Pass LoRA weights to the pipeline
        )

        if stream:
            return StreamingResponse(stream_generator(generator), media_type="text/event-stream")
        else:
            full_response = ""
            tokens_used = 0
            async for chunk in generator:
                if isinstance(chunk, dict):
                    tokens_used = chunk["tokens_used"]
                    break
                full_response += chunk

            return LlmResponse(response=full_response, tokens_used=tokens_used)

    except json.JSONDecodeError:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "Invalid JSON format for history"},
        )
    except ValueError as ve:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(ve)},
        )
    except Exception as e:
        logger.error(f"LLM processing error: {str(e)}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "Internal server error during LLM processing."},
        )


async def stream_generator(generator):
    try:
        async for chunk in generator:
            if isinstance(chunk, dict):  # This is the final result
                yield f"data: {json.dumps(chunk)}\n\n"
                break
            else:
                yield f"data: {json.dumps({'chunk': chunk})}\n\n"
        yield "data: [DONE]\n\n"
    except Exception as e:
        logger.error(f"Streaming error: {str(e)}")
        yield f"data: {json.dumps({'error': str(e)})}\n\n"