Commit 2821b8d (parent 2d52b95): 14 changed files with 2,408 additions and 6 deletions.
@@ -0,0 +1,3 @@
from .client import Client
from .error import CohereError
from .mode import Mode
@@ -0,0 +1,325 @@
from cohere_aws.response import CohereObject
from cohere_aws.error import CohereError
from cohere_aws.mode import Mode
from typing import List, Optional, Generator, Dict, Any, Union
from enum import Enum
import json


# Tools


class ToolParameterDefinitionsValue(CohereObject, dict):
    def __init__(
        self,
        type: str,
        description: str,
        required: Optional[bool] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.__dict__ = self
        self.type = type
        self.description = description
        if required is not None:
            self.required = required


class Tool(CohereObject, dict):
    def __init__(
        self,
        name: str,
        description: str,
        parameter_definitions: Optional[Dict[str, ToolParameterDefinitionsValue]] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.__dict__ = self
        self.name = name
        self.description = description
        if parameter_definitions is not None:
            self.parameter_definitions = parameter_definitions


class ToolCall(CohereObject, dict):
    def __init__(
        self,
        name: str,
        parameters: Dict[str, Any],
        generation_id: str,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.__dict__ = self
        self.name = name
        self.parameters = parameters
        self.generation_id = generation_id

    @classmethod
    def from_dict(cls, tool_call_res: Dict[str, Any]) -> "ToolCall":
        return cls(
            name=tool_call_res.get("name"),
            parameters=tool_call_res.get("parameters"),
            generation_id=tool_call_res.get("generation_id"),
        )

    @classmethod
    def from_list(cls, tool_calls_res: Optional[List[Dict[str, Any]]]) -> Optional[List["ToolCall"]]:
        if tool_calls_res is None or not isinstance(tool_calls_res, list):
            return None

        return [ToolCall.from_dict(tc) for tc in tool_calls_res]

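Note: the three tool classes above are dict-backed containers; the `self.__dict__ = self` trick makes attribute access and key access interchangeable, so the objects JSON-encode directly. A minimal usage sketch, assuming the classes above are in scope; the tool name, parameter, and IDs are invented for illustration:

import json

# Hypothetical tool definition (names and values are illustrative, not part of the diff).
sales_tool = Tool(
    name="query_daily_sales",
    description="Look up total sales for a given day.",
    parameter_definitions={
        "day": ToolParameterDefinitionsValue(
            type="str",
            description="Day to query, formatted YYYY-MM-DD.",
            required=True,
        ),
    },
)

# Because the classes subclass dict, they serialize straight to JSON.
print(json.dumps(sales_tool))

# Parsing a tool call returned by the model:
call = ToolCall.from_dict(
    {"name": "query_daily_sales", "parameters": {"day": "2024-03-01"}, "generation_id": "gen-123"}
)
print(call.name, call.parameters)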
# Chat


class Chat(CohereObject):
    def __init__(
        self,
        response_id: str,
        generation_id: str,
        text: str,
        chat_history: Optional[List[Dict[str, Any]]] = None,
        preamble: Optional[str] = None,
        finish_reason: Optional[str] = None,
        token_count: Optional[Dict[str, int]] = None,
        tool_calls: Optional[List[ToolCall]] = None,
        citations: Optional[List[Dict[str, Any]]] = None,
        documents: Optional[List[Dict[str, Any]]] = None,
        search_results: Optional[List[Dict[str, Any]]] = None,
        search_queries: Optional[List[Dict[str, Any]]] = None,
        is_search_required: Optional[bool] = None,
    ) -> None:
        self.response_id = response_id
        self.generation_id = generation_id
        self.text = text
        self.chat_history = chat_history
        self.preamble = preamble
        self.finish_reason = finish_reason
        self.token_count = token_count
        self.tool_calls = tool_calls
        self.citations = citations
        self.documents = documents
        self.search_results = search_results
        self.search_queries = search_queries
        self.is_search_required = is_search_required

    @classmethod
    def from_dict(cls, response: Dict[str, Any]) -> "Chat":
        return cls(
            response_id=response["response_id"],
            generation_id=response.get("generation_id"),  # optional
            text=response.get("text"),
            chat_history=response.get("chat_history"),  # optional
            preamble=response.get("preamble"),  # optional
            token_count=response.get("token_count"),
            is_search_required=response.get("is_search_required"),  # optional
            citations=response.get("citations"),  # optional
            documents=response.get("documents"),  # optional
            search_results=response.get("search_results"),  # optional
            search_queries=response.get("search_queries"),  # optional
            finish_reason=response.get("finish_reason"),
            tool_calls=ToolCall.from_list(response.get("tool_calls")),  # optional
        )

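Note: Chat.from_dict is where an already-decoded response dict becomes a typed object; only response_id is mandatory (a missing key raises KeyError), while every other field falls back to None via .get, and tool_calls is converted into ToolCall objects. A hedged sketch with an invented payload, assuming Chat is in scope:

# Illustrative payload only; field values are made up.
raw = {
    "response_id": "resp-abc",
    "generation_id": "gen-123",
    "text": "Total sales on 2024-03-01 were $12,000.",
    "finish_reason": "COMPLETE",
    "token_count": {"prompt_tokens": 52, "response_tokens": 14},
    "tool_calls": [
        {"name": "query_daily_sales", "parameters": {"day": "2024-03-01"}, "generation_id": "gen-123"}
    ],
}

chat = Chat.from_dict(raw)
print(chat.text, chat.finish_reason)
if chat.tool_calls:
    for tc in chat.tool_calls:  # each item is a ToolCall, not a bare dict
        print(tc.name, tc.parameters)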
# ----------------|
# Streaming event |
# ----------------|

class StreamEvent(str, Enum):
    STREAM_START = "stream-start"
    SEARCH_QUERIES_GENERATION = "search-queries-generation"
    SEARCH_RESULTS = "search-results"
    TEXT_GENERATION = "text-generation"
    TOOL_CALLS_GENERATION = "tool-calls-generation"
    CITATION_GENERATION = "citation-generation"
    STREAM_END = "stream-end"

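Note: because StreamEvent mixes in str, its members compare equal to the raw event_type strings decoded from each streamed JSON chunk, which is what lets _make_response_item below match them without any conversion. A small illustration, assuming the enum above is in scope:

# Members of a (str, Enum) compare equal to plain strings, so event_type values
# taken straight from decoded JSON can be matched against the enum directly.
assert StreamEvent.TEXT_GENERATION == "text-generation"
assert "stream-end" == StreamEvent.STREAM_END
# Lookup by value also works, returning the corresponding member.
assert StreamEvent("citation-generation") is StreamEvent.CITATION_GENERATION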
class StreamResponse(CohereObject):
    def __init__(
        self,
        is_finished: bool,
        event_type: Union[StreamEvent, str],
        index: Optional[int],
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.is_finished = is_finished
        self.index = index
        self.event_type = event_type


class StreamStart(StreamResponse):
    def __init__(
        self,
        generation_id: str,
        conversation_id: Optional[str],
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.generation_id = generation_id
        self.conversation_id = conversation_id


class StreamTextGeneration(StreamResponse):
    def __init__(
        self,
        text: str,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.text = text


class StreamCitationGeneration(StreamResponse):
    def __init__(
        self,
        citations: Optional[List[Dict[str, Any]]],
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.citations = citations


class StreamQueryGeneration(StreamResponse):
    def __init__(
        self,
        search_queries: Optional[List[Dict[str, Any]]],
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.search_queries = search_queries


class StreamSearchResults(StreamResponse):
    def __init__(
        self,
        search_results: Optional[List[Dict[str, Any]]],
        documents: Optional[List[Dict[str, Any]]],
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.search_results = search_results
        self.documents = documents


class StreamEnd(StreamResponse):
    def __init__(
        self,
        finish_reason: str,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.finish_reason = finish_reason


class ChatToolCallsGenerationEvent(StreamResponse):
    def __init__(
        self,
        tool_calls: Optional[List[ToolCall]],
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.tool_calls = tool_calls


class StreamingChat(CohereObject):
    def __init__(self, stream_response, mode):
        self.stream_response = stream_response
        self.text = None
        self.response_id = None
        self.generation_id = None
        self.preamble = None
        self.prompt = None
        self.chat_history = None
        self.finish_reason = None
        self.token_count = None
        self.is_search_required = None
        self.citations = None
        self.documents = None
        self.search_results = None
        self.search_queries = None
        self.tool_calls = None

        self.bytes = bytearray()
        if mode == Mode.SAGEMAKER:
            self.payload_key = "PayloadPart"
            self.bytes_key = "Bytes"
        elif mode == Mode.BEDROCK:
            self.payload_key = "chunk"
            self.bytes_key = "bytes"

    def _make_response_item(self, index, streaming_item) -> Any:
        event_type = streaming_item.get("event_type")

        if event_type == StreamEvent.STREAM_START:
            self.conversation_id = streaming_item.get("conversation_id")
            self.generation_id = streaming_item.get("generation_id")
            return StreamStart(
                conversation_id=self.conversation_id,
                generation_id=self.generation_id,
                is_finished=False,
                event_type=event_type,
                index=index,
            )
        elif event_type == StreamEvent.SEARCH_QUERIES_GENERATION:
            search_queries = streaming_item.get("search_queries")
            return StreamQueryGeneration(
                search_queries=search_queries, is_finished=False, event_type=event_type, index=index
            )
        elif event_type == StreamEvent.SEARCH_RESULTS:
            search_results = streaming_item.get("search_results")
            documents = streaming_item.get("documents")
            return StreamSearchResults(
                search_results=search_results,
                documents=documents,
                is_finished=False,
                event_type=event_type,
                index=index,
            )
        elif event_type == StreamEvent.TEXT_GENERATION:
            text = streaming_item.get("text")
            return StreamTextGeneration(text=text, is_finished=False, event_type=event_type, index=index)
        elif event_type == StreamEvent.CITATION_GENERATION:
            citations = streaming_item.get("citations")
            return StreamCitationGeneration(citations=citations, is_finished=False, event_type=event_type, index=index)
        elif event_type == StreamEvent.TOOL_CALLS_GENERATION:
            tool_calls = ToolCall.from_list(streaming_item.get("tool_calls"))
            return ChatToolCallsGenerationEvent(
                tool_calls=tool_calls, is_finished=False, event_type=event_type, index=index
            )
        elif event_type == StreamEvent.STREAM_END:
            response = streaming_item.get("response")
            finish_reason = streaming_item.get("finish_reason")
            self.finish_reason = finish_reason

            if response is None:
                return None

            self.response_id = response.get("response_id")
            self.conversation_id = response.get("conversation_id")
            self.text = response.get("text")
            self.generation_id = response.get("generation_id")
            self.preamble = response.get("preamble")
            self.prompt = response.get("prompt")
            self.chat_history = response.get("chat_history")
            self.token_count = response.get("token_count")
            self.is_search_required = response.get("is_search_required")  # optional
            self.citations = response.get("citations")  # optional
            self.documents = response.get("documents")  # optional
            self.search_results = response.get("search_results")  # optional
            self.search_queries = response.get("search_queries")  # optional
            self.tool_calls = ToolCall.from_list(response.get("tool_calls"))  # optional
            return StreamEnd(finish_reason=finish_reason, is_finished=True, event_type=event_type, index=index)
        return None

    def __iter__(self) -> Generator[StreamResponse, None, None]:
        index = 0
        for payload in self.stream_response:
            self.bytes.extend(payload[self.payload_key][self.bytes_key])
            try:
                item = self._make_response_item(index, json.loads(self.bytes))
            except json.decoder.JSONDecodeError:
                # payload contained only a partial JSON object
                continue

            self.bytes = bytearray()
            if item is not None:
                index += 1
                yield item
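Note: StreamingChat buffers raw bytes per event (PayloadPart/Bytes for SageMaker, chunk/bytes for Bedrock), decodes each complete JSON object into a typed event, and backfills the aggregate fields (text, finish_reason, tool_calls, and so on) when the stream-end event carries the final response. A hedged consumption sketch over a hand-built SageMaker-shaped stream; the payloads are invented, and in real use stream_response would be the event stream returned by the endpoint invocation:

import json

# Fake SageMaker-style event stream: each element mimics one streamed payload.
fake_stream = [
    {"PayloadPart": {"Bytes": json.dumps(
        {"event_type": "stream-start", "generation_id": "gen-123", "conversation_id": None}).encode()}},
    {"PayloadPart": {"Bytes": json.dumps(
        {"event_type": "text-generation", "text": "Hello"}).encode()}},
    {"PayloadPart": {"Bytes": json.dumps(
        {"event_type": "stream-end", "finish_reason": "COMPLETE",
         "response": {"response_id": "resp-abc", "generation_id": "gen-123", "text": "Hello"}}).encode()}},
]

streaming = StreamingChat(fake_stream, Mode.SAGEMAKER)
for event in streaming:
    if isinstance(event, StreamTextGeneration):
        print(event.text, end="", flush=True)   # incremental text chunks
    elif isinstance(event, StreamEnd):
        print("\nfinished:", event.finish_reason)

# After iteration, the aggregate fields are populated from the final response.
print(streaming.text, streaming.finish_reason)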
src/cohere/manually_maintained/cohere_aws/classification.py (60 additions, 0 deletions)
@@ -0,0 +1,60 @@
from cohere_aws.response import CohereObject
from typing import Any, Dict, Iterator, List, Literal, Union

Prediction = Union[str, int, List[str], List[int]]
ClassificationDict = Dict[Literal["prediction", "confidence", "text"], Any]


class Classification(CohereObject):
    def __init__(self, classification: Union[Prediction, ClassificationDict]) -> None:
        # Prediction is the old format (version 1 of classification-finetuning)
        # ClassificationDict is the new format (version 2 of classification-finetuning).
        # It also contains the original text and the labels' confidence scores of the prediction
        self.classification = classification

    def is_multilabel(self) -> bool:
        if isinstance(self.classification, list):
            return True
        elif isinstance(self.classification, (int, str)):
            return False
        return isinstance(self.classification["prediction"], list)

    @property
    def prediction(self) -> Prediction:
        if isinstance(self.classification, (list, int, str)):
            return self.classification
        return self.classification["prediction"]

    @property
    def confidence(self) -> List[float]:
        if isinstance(self.classification, (list, int, str)):
            raise ValueError(
                "Confidence scores are not available for version prior to 2.0 of Cohere Classification Finetuning AWS package"
            )
        return self.classification["confidence"]

    @property
    def text(self) -> str:
        if isinstance(self.classification, (list, int, str)):
            raise ValueError(
                "Original text is not available for version prior to 2.0 of Cohere Classification Finetuning AWS package"
            )
        return self.classification["text"]


class Classifications(CohereObject):
    def __init__(self, classifications: List[Classification]) -> None:
        self.classifications = classifications
        if len(self.classifications) > 0:
            assert all(
                [c.is_multilabel() == self.is_multilabel() for c in self.classifications]
            ), "All classifications must be of the same type (single-label or multi-label)"

    def __iter__(self) -> Iterator:
        return iter(self.classifications)

    def __len__(self) -> int:
        return len(self.classifications)

    def is_multilabel(self) -> bool:
        return len(self.classifications) > 0 and self.classifications[0].is_multilabel()
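Note: Classification wraps either the bare prediction returned by version 1 of the classification-finetuning package or the richer dict returned by version 2; confidence and text are only reachable in the latter case and raise ValueError otherwise. A brief sketch with invented values, assuming the classes above are in scope:

# Version 1 style: the endpoint returned only the prediction.
old = Classification("positive")
print(old.prediction, old.is_multilabel())   # "positive", False

# Version 2 style: prediction plus confidence scores and the original text.
new = Classification({
    "prediction": ["billing", "refund"],     # multi-label example
    "confidence": [0.92, 0.81],
    "text": "I was charged twice and want my money back.",
})
print(new.prediction, new.confidence, new.is_multilabel())   # lists plus True

# All classifications in a batch must agree on single- vs multi-label.
batch = Classifications([new])
print(len(batch), batch.is_multilabel())     # 1, True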