Skip to content

Commit

Permalink
Merge pull request #1323 from LLukas22/ollama-json-mode
Browse files Browse the repository at this point in the history
feat(dspy): add system and format options to ollama
  • Loading branch information
arnavsinghvi11 authored Jul 29, 2024
2 parents f666e28 + 871eebe commit b1da1af
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion dsp/modules/ollama.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import hashlib
from typing import Any, Literal
from typing import Any, Literal, Optional

import requests

Expand All @@ -25,6 +25,8 @@ class OllamaLocal(LM):
model_type (Literal["chat", "text"], optional): The type of model that was specified. Mainly to decide the optimal prompting strategy. Defaults to "text".
base_url (str): Protocol, host name, and port to the served ollama model. Defaults to "http://localhost:11434" as in ollama docs.
timeout_s (float): Timeout period (in seconds) for the post request to llm.
format (str): The format to return a response in. Currently the only accepted value is `json`
system (str): System Prompt to use when running in `text` mode.
**kwargs: Additional arguments to pass to the API.
"""

Expand All @@ -42,6 +44,8 @@ def __init__(
presence_penalty: float = 0,
n: int = 1,
num_ctx: int = 1024,
format: Optional[Literal["json"]] = None,
system: Optional[str] = None,
**kwargs,
):
super().__init__(model)
Expand All @@ -51,6 +55,8 @@ def __init__(
self.base_url = base_url
self.model_name = model
self.timeout_s = timeout_s
self.format = format
self.system = system

self.kwargs = {
"temperature": temperature,
Expand Down Expand Up @@ -86,9 +92,18 @@ def basic_request(self, prompt: str, **kwargs):
"options": {k: v for k, v in kwargs.items() if k not in ["n", "max_tokens"]},
"stream": False,
}

# Set the format if it was defined
if self.format:
settings_dict["format"] = self.format

if self.model_type == "chat":
settings_dict["messages"] = [{"role": "user", "content": prompt}]
else:
# Overwrite system prompt defined in modelfile
if self.system:
settings_dict["system"] = self.system

settings_dict["prompt"] = prompt

urlstr = f"{self.base_url}/api/chat" if self.model_type == "chat" else f"{self.base_url}/api/generate"
Expand Down

0 comments on commit b1da1af

Please sign in to comment.