implement instruct generation
andre15silva committed Aug 9, 2024
1 parent cde11ca commit a136f46
Showing 4 changed files with 182 additions and 20 deletions.
Empty file (likely the new codellama package __init__.py).
elleelleaime/generate/strategies/models/huggingface/codellama/codellama_infilling.py
@@ -11,20 +11,19 @@
@dataclass
class GenerateSettings:
    name: str
-    num_beams: int = 1
    do_sample: bool = False
-    temperature: float = 0.0
-    max_new_tokens: int = 128
+    temperature: float = 1.0
+    num_beams: int = 1
    num_return_sequences: int = 10
-    max_new_tokens: int = 1024
+    max_new_tokens: int = 4096


-class CodeLlamaHFModels(PatchGenerationStrategy):
+class CodeLLaMAInfilling(PatchGenerationStrategy):
    __SUPPORTED_MODELS = {
-        "codellama/CodeLlama-7b-hf",
-        "codellama/CodeLlama-13b-hf",
-        "codellama/CodeLlama-7b-Instruct-hf",
-        "codellama/CodeLlama-13b-Instruct-hf",
+        "meta-llama/CodeLlama-7b-hf",
+        "meta-llama/CodeLlama-13b-hf",
+        "meta-llama/CodeLlama-34b-hf",
+        "meta-llama/CodeLlama-70b-hf",
    }

__GENERATION_STRATEGIES = {
@@ -45,28 +44,34 @@ class CodeLlamaHFModels(PatchGenerationStrategy):
    def __init__(self, model_name: str, **kwargs) -> None:
        assert (
            model_name in self.__SUPPORTED_MODELS
-        ), f"Model {model_name} not supported by CodeLlamaHFModels"
+        ), f"Model {model_name} not supported by {self.__class__.__name__}"
        self.model_name = model_name
        self.__load_model()
        # Generation settings
        assert (
            kwargs.get("generation_strategy", "beam_search")
            in self.__GENERATION_STRATEGIES
-        ), f"Generation strategy {kwargs.get('generation_strategy', 'beam_search')} not supported by CodeLlamaHFModels"
+        ), f"Generation strategy {kwargs.get('generation_strategy', 'beam_search')} not supported by {self.__class__.__name__}"
        self.generate_settings = self.__GENERATION_STRATEGIES[
            kwargs.get("generation_strategy", "beam_search")
        ]
-        self.generate_settings.max_new_tokens = kwargs.get("max_new_tokens", 128)
+        self.generate_settings.max_new_tokens = kwargs.get(
+            "max_new_tokens", GenerateSettings.max_new_tokens
+        )
        self.generate_settings.num_return_sequences = kwargs.get(
-            "num_return_sequences", 10
+            "num_return_sequences", GenerateSettings.num_return_sequences
        )
-        self.generate_settings.num_beams = kwargs.get("num_beams", 1)
-        self.generate_settings.temperature = kwargs.get("temperature", 0.2)
+        self.generate_settings.num_beams = kwargs.get(
+            "num_beams", GenerateSettings.num_beams
+        )
+        self.generate_settings.temperature = kwargs.get(
+            "temperature", GenerateSettings.temperature
+        )

    def __load_model(self):
        # Setup environment
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.context_size = 2048
+        self.context_size = 16384

        # Setup kwargs
        kwargs = dict(
@@ -81,6 +86,7 @@ def __load_model(self):
        self.__MODEL = LlamaForCausalLM.from_pretrained(
            self.model_name, device_map="auto", **kwargs
        )
+        self.__MODEL.eval()
        self.__MODELS_LOADED = True

def _generate_impl(self, prompt: str) -> Any:
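For context, a minimal sketch (not part of the diff) of constructing the renamed infilling strategy with the kwargs handled above; the model name and settings are illustrative values rather than project defaults, and a downloaded checkpoint plus a CUDA device with enough memory are assumed. The generation call itself is collapsed in this hunk.

from elleelleaime.generate.strategies.models.huggingface.codellama.codellama_infilling import (
    CodeLLaMAInfilling,
)

# Illustrative overrides; each kwarg falls back to the GenerateSettings
# default when omitted, as handled in __init__ above
strategy = CodeLLaMAInfilling(
    "meta-llama/CodeLlama-7b-hf",
    generation_strategy="beam_search",
    num_beams=5,
    num_return_sequences=5,
    max_new_tokens=1024,
)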
elleelleaime/generate/strategies/models/huggingface/codellama/codellama_instruct.py (new file)
@@ -0,0 +1,135 @@
from elleelleaime.generate.strategies.strategy import PatchGenerationStrategy
from dataclasses import dataclass
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Any

import torch
import threading
import logging


@dataclass
class GenerateSettings:
name: str
do_sample: bool = False
temperature: float = 1.0
num_beams: int = 1
num_return_sequences: int = 10
max_new_tokens: int = 4096


class CodeLLaMAInstruct(PatchGenerationStrategy):
__SUPPORTED_MODELS = {
"meta-llama/CodeLlama-7b-Instruct-hf",
"meta-llama/CodeLlama-13b-Instruct-hf",
"meta-llama/CodeLlama-34b-Instruct-hf",
"meta-llama/CodeLlama-70b-Instruct-hf",
}

__GENERATION_STRATEGIES = {
"beam_search": GenerateSettings(
name="beam_search",
),
"sampling": GenerateSettings(
name="sampling",
do_sample=True,
),
}

__MODEL = None
__TOKENIZER = None
__MODELS_LOADED: bool = False
__MODELS_LOCK: threading.Lock = threading.Lock()

def __init__(self, model_name: str, **kwargs) -> None:
assert (
model_name in self.__SUPPORTED_MODELS
), f"Model {model_name} not supported by {self.__class__.__name__}"
self.model_name = model_name
self.__load_model(**kwargs)
# Generation settings
assert (
kwargs.get("generation_strategy", "sampling")
in self.__GENERATION_STRATEGIES
), f"Generation strategy {kwargs.get('generation_strategy', 'samlping')} not supported by {self.__class__.__name__}"
self.generate_settings = self.__GENERATION_STRATEGIES[
            kwargs.get("generation_strategy", "sampling")
]
self.generate_settings.max_new_tokens = kwargs.get(
"max_new_tokens", GenerateSettings.max_new_tokens
)
self.generate_settings.num_return_sequences = kwargs.get(
"num_return_sequences", GenerateSettings.num_return_sequences
)
self.generate_settings.num_beams = kwargs.get(
"num_beams", GenerateSettings.num_beams
)
self.generate_settings.temperature = kwargs.get(
"temperature", GenerateSettings.temperature
)

    def __load_model(self, **kwargs):
        # Setup environment
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.context_size = 16384

        # Model-loading kwargs, kept separate so the caller's kwargs
        # (which may carry adapter_name) are not shadowed
        model_kwargs = dict(
            torch_dtype=torch.bfloat16,
        )

# Load the model and tokenizer
with self.__MODELS_LOCK:
if self.__MODELS_LOADED:
return

# Load tokenizer
self.__TOKENIZER = AutoTokenizer.from_pretrained(self.model_name)
self.__TOKENIZER.pad_token = self.__TOKENIZER.eos_token
# Load model
            self.__MODEL = AutoModelForCausalLM.from_pretrained(
                self.model_name, device_map="auto", **model_kwargs
            )
            # Load LoRA adapter (if requested) and fold it into the base
            # weights; merge_and_unload returns the merged model
            if kwargs.get("adapter_name", None) is not None:
                self.__MODEL = PeftModel.from_pretrained(
                    self.__MODEL, kwargs.get("adapter_name")
                )
                self.__MODEL = self.__MODEL.merge_and_unload()
            self.__MODEL.eval()
self.__MODELS_LOADED = True

    def __format_prompt(self, prompt: str) -> str:
        # CodeLlama-Instruct chat template; the closing tag is [/INST]
        return f"<s>[INST] {prompt} [/INST]"

def _generate_impl(self, prompt: str) -> Any:
formatted_prompt = self.__format_prompt(prompt)

input_ids = self.__TOKENIZER(formatted_prompt, return_tensors="pt")["input_ids"]
input_ids = input_ids.to(self.device)

max_length = self.generate_settings.max_new_tokens + input_ids.shape[1]
if max_length > self.context_size:
logging.warning(
"warning: max_length %s is greater than the context window %s"
% (max_length, self.context_size)
)
return None

with torch.no_grad():
generated_ids = self.__MODEL.generate(
input_ids,
max_new_tokens=self.generate_settings.max_new_tokens,
num_beams=self.generate_settings.num_beams,
num_return_sequences=self.generate_settings.num_return_sequences,
early_stopping=True,
do_sample=self.generate_settings.do_sample,
temperature=self.generate_settings.temperature,
)

input_len = input_ids.shape[1]
fillings_ids = generated_ids[:, input_len:]
fillings = self.__TOKENIZER.batch_decode(fillings_ids, skip_special_tokens=True)

return list(fillings)
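A minimal usage sketch of the new instruct strategy, calling the private _generate_impl directly for illustration (the public entry point lives on the PatchGenerationStrategy base class, which is outside this diff); the model choice, sampling settings, and prompt text are made-up examples, and a local checkpoint plus sufficient GPU memory are assumed.

from elleelleaime.generate.strategies.models.huggingface.codellama.codellama_instruct import (
    CodeLLaMAInstruct,
)

strategy = CodeLLaMAInstruct(
    "meta-llama/CodeLlama-7b-Instruct-hf",
    generation_strategy="sampling",
    temperature=0.8,
    num_return_sequences=5,
)

# The prompt is wrapped in the [INST] template internally; the result is a
# list of 5 decoded completions, or None if prompt length plus
# max_new_tokens would exceed the 16384-token context window
patches = strategy._generate_impl("Fix the following buggy Java method: ...")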
29 changes: 25 additions & 4 deletions elleelleaime/generate/strategies/registry.py
@@ -2,8 +2,11 @@
from elleelleaime.generate.strategies.models.openai.openai import (
    OpenAIChatCompletionModels,
)
-from elleelleaime.generate.strategies.models.huggingface.codellama import (
-    CodeLlamaHFModels,
+from elleelleaime.generate.strategies.models.huggingface.codellama.codellama_infilling import (
+    CodeLLaMAInfilling,
+)
+from elleelleaime.generate.strategies.models.huggingface.codellama.codellama_instruct import (
+    CodeLLaMAInstruct,
)

from typing import Tuple
@@ -21,8 +24,26 @@ class PatchGenerationStrategyRegistry:
"gpt-3.5-turbo": (OpenAIChatCompletionModels, ("gpt-3.5-turbo",)),
"gpt-4o-mini": (OpenAIChatCompletionModels, ("gpt-4o-mini",)),
# HuggingFace models
"codellama-7b": (CodeLlamaHFModels, ("codellama/CodeLlama-7b-hf",)),
"codellama-13b": (CodeLlamaHFModels, ("codellama/CodeLlama-13b-hf",)),
"codellama-7b": (CodeLLaMAInfilling, ("meta-llama/CodeLlama-7b-hf",)),
"codellama-13b": (CodeLLaMAInfilling, ("meta-llama/CodeLlama-13b-hf",)),
"codellama-34b": (CodeLLaMAInfilling, ("meta-llama/CodeLlama-34b-hf",)),
"codellama-70b": (CodeLLaMAInfilling, ("meta-llama/CodeLlama-70b-hf",)),
"codellama-7b-instruct": (
CodeLLaMAIntruct,
("meta-llama/CodeLlama-7b-Instruct-hf",),
),
"codellama-13b-instruct": (
CodeLLaMAIntruct,
("meta-llama/CodeLlama-13b-Instruct-hf",),
),
"codellama-34b-instruct": (
CodeLLaMAIntruct,
("meta-llama/CodeLlama-34b-Instruct-hf",),
),
"codellama-70b-instruct": (
CodeLLaMAIntruct,
("meta-llama/CodeLlama-70b-Instruct-hf",),
),
}

@classmethod
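For illustration, a sketch of how an entry in this table resolves to a strategy instance; the lookup below is hypothetical, since the registry's classmethod body is collapsed in this diff.

from elleelleaime.generate.strategies.models.huggingface.codellama.codellama_instruct import (
    CodeLLaMAInstruct,
)

# Hypothetical resolution mirroring the table above; the real accessor
# method is not visible in this diff
registry = {
    "codellama-7b-instruct": (
        CodeLLaMAInstruct,
        ("meta-llama/CodeLlama-7b-Instruct-hf",),
    ),
}
strategy_class, args = registry["codellama-7b-instruct"]
strategy = strategy_class(*args, generation_strategy="sampling")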
