diff --git a/dsp/modules/ollama.py b/dsp/modules/ollama.py
index 7781583be..3ea025c8a 100644
--- a/dsp/modules/ollama.py
+++ b/dsp/modules/ollama.py
@@ -1,6 +1,6 @@
 import datetime
 import hashlib
-from typing import Any, Literal
+from typing import Any, Literal, Optional
 
 import requests
 
@@ -25,6 +25,8 @@ class OllamaLocal(LM):
         model_type (Literal["chat", "text"], optional): The type of model that was specified. Mainly to decide the optimal prompting strategy. Defaults to "text".
         base_url (str): Protocol, host name, and port to the served ollama model. Defaults to "http://localhost:11434" as in ollama docs.
         timeout_s (float): Timeout period (in seconds) for the post request to llm.
+        format (str): The format to return a response in. Currently the only accepted value is `json`.
+        system (str): System prompt to use when running in `text` mode.
         **kwargs: Additional arguments to pass to the API.
     """
 
@@ -42,6 +44,8 @@ def __init__(
         presence_penalty: float = 0,
         n: int = 1,
         num_ctx: int = 1024,
+        format: Optional[Literal["json"]] = None,
+        system: Optional[str] = None,
         **kwargs,
     ):
         super().__init__(model)
@@ -51,6 +55,8 @@ def __init__(
         self.base_url = base_url
         self.model_name = model
         self.timeout_s = timeout_s
+        self.format = format
+        self.system = system
 
         self.kwargs = {
             "temperature": temperature,
@@ -86,9 +92,18 @@ def basic_request(self, prompt: str, **kwargs):
             "options": {k: v for k, v in kwargs.items() if k not in ["n", "max_tokens"]},
             "stream": False,
         }
+
+        # Set the format if it was defined
+        if self.format:
+            settings_dict["format"] = self.format
+
         if self.model_type == "chat":
             settings_dict["messages"] = [{"role": "user", "content": prompt}]
         else:
+            # Overwrite the system prompt defined in the Modelfile
+            if self.system:
+                settings_dict["system"] = self.system
+
             settings_dict["prompt"] = prompt
 
         urlstr = f"{self.base_url}/api/chat" if self.model_type == "chat" else f"{self.base_url}/api/generate"
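
For reviewers, here is a minimal usage sketch of the two new parameters. It is not part of the patch: the import path is inferred from the file path in the diff, the model name "mistral" is purely illustrative, and it assumes an ollama server is listening on the default `base_url`.

```python
# Minimal usage sketch for the new `format` and `system` parameters
# (illustrative, not part of the patch). Assumes an ollama server on the
# default http://localhost:11434 and a locally pulled model named "mistral".
from dsp.modules.ollama import OllamaLocal

lm = OllamaLocal(
    model="mistral",    # illustrative; substitute any locally served model
    model_type="text",  # `system` is only injected on the "text" path (see the else branch)
    format="json",      # forwarded as settings_dict["format"] in both modes
    system="Answer with a single JSON object.",  # overrides the Modelfile system prompt
)

# basic_request is the method patched above; how its return value is shaped
# by the rest of the class is not shown in this diff.
response = lm.basic_request("List the first three prime numbers.")
print(response)
```

Note the asymmetry in the patch: `format` is set on `settings_dict` before the mode branch, so it applies to both `/api/chat` and `/api/generate` requests, while `system`, as the docstring addition states, only takes effect when running in `text` mode.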