Revert augmented functionality (#267)
harry-cohere authored Jul 24, 2023
1 parent abf9cbd commit 38a7faa
Showing 6 changed files with 10 additions and 123 deletions.
2 changes: 0 additions & 2 deletions CHANGELOG.md
@@ -44,8 +44,6 @@
- Better string representation for DetectLanguageResponse
- [#249](https://github.com/cohere-ai/cohere-python/pull/249)
- Catch ClientPayloadError in AsyncClient and convert it to a CohereAPIError
- [#250](https://github.com/cohere-ai/cohere-python/pull/250)
- Add support for query generation and documents in chat (non-streaming)

## 4.10.0

21 changes: 1 addition & 20 deletions cohere/client.py
@@ -29,7 +29,7 @@
Tokens,
)
from cohere.responses.bulk_embed import BulkEmbedJob, CreateBulkEmbedJobResponse
from cohere.responses.chat import Chat, Mode, StreamingChat
from cohere.responses.chat import Chat, StreamingChat
from cohere.responses.classify import Example as ClassifyExample
from cohere.responses.classify import LabelPrediction
from cohere.responses.cluster import ClusterJobResult, CreateClusterJobResponse
@@ -220,8 +220,6 @@ def chat(
p: Optional[float] = None,
k: Optional[float] = None,
logit_bias: Optional[Dict[int, float]] = None,
mode: Optional[Mode] = None,
documents: Optional[List[Dict[str, str]]] = None,
) -> Union[Chat, StreamingChat]:
"""Returns a Chat object with the query reply.
@@ -242,11 +240,6 @@
p (float): (Optional) The nucleus sampling probability.
k (float): (Optional) The top-k sampling probability.
logit_bias (Dict[int, float]): (Optional) A dictionary of logit bias values to use for the next reply.
mode Mode: (Optional) This property determines which functionality of retrieval augmented generation to use.
chat mode doesn't use any retrieval augmented generation functionality.
search_query_generation uses the provided query to produce search terms that you can use to search for documents.
augmented_generation uses the provided documents and query to produce citations
document Document: (Optional) The documents to use in augmented_generation mode. Shape: ("title", str), ("snippet", str), ("url", str)
Returns:
a Chat object if stream=False, or a StreamingChat object if stream=True
@@ -279,16 +272,6 @@ def chat(
>>> return_prompt=True)
>>> print(res.text)
>>> print(res.prompt)
Query generation example:
>>> res = co.chat(query="What are the tallest penguins?", mode="search_query_generation")
>>> print(res.queries)
>>> print(res.is_search_required)
Augmented generation example:
>>> res = co.chat(query="What are the tallest penguins?",
mode="augmented_generation",
documents = [{"title":"Tall penguins", "snippet":"Emperor penguins are the tallest", "url":"http://example.com/foo"}])
>>> print(res.text)
>>> print(res.citations)
"""
if chat_history is not None:
should_warn = True
@@ -326,8 +309,6 @@ def chat(
"p": p,
"k": k,
"logit_bias": logit_bias,
"mode": mode,
"documents": documents,
}
response = self._request(cohere.CHAT_URL, json=json_body, stream=stream)

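For reference, a minimal usage sketch of the synchronous Client.chat call after this revert, adapted from the docstring example above; the API key string is a placeholder, and the removed mode/documents arguments are intentionally absent:

    import cohere

    co = cohere.Client("YOUR_API_KEY")  # placeholder credential

    # Plain, non-streaming chat using only parameters that survive the revert.
    res = co.chat("What are the tallest penguins?", return_prompt=True, max_tokens=100)
    print(res.text)
    print(res.prompt)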
6 changes: 1 addition & 5 deletions cohere/client_async.py
@@ -43,7 +43,7 @@
Tokens,
)
from cohere.responses.bulk_embed import AsyncCreateBulkEmbedJobResponse, BulkEmbedJob
from cohere.responses.chat import AsyncChat, Mode, StreamingChat
from cohere.responses.chat import AsyncChat, StreamingChat
from cohere.responses.classify import Example as ClassifyExample
from cohere.responses.custom_model import (
CUSTOM_MODEL_PRODUCT_MAPPING,
@@ -206,8 +206,6 @@ async def chat(
p: Optional[float] = None,
k: Optional[float] = None,
logit_bias: Optional[Dict[int, float]] = None,
mode: Optional[Mode] = None,
documents: Optional[List[Dict[str, str]]] = None,
) -> Union[AsyncChat, StreamingChat]:
if chat_history is not None:
should_warn = True
@@ -248,8 +246,6 @@ async def chat(
"p": p,
"k": k,
"logit_bias": logit_bias,
"mode": mode,
"documents": documents,
}

response = await self._request(cohere.CHAT_URL, json=json_body, stream=stream)
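The async client mirrors the same reduced request body. A sketch, assuming the client is constructed directly with an API key and cleaned up with an awaitable close() (the tests below only use a pre-built async_client fixture, so both details are assumptions):

    import asyncio

    import cohere

    async def main():
        aio_co = cohere.AsyncClient("YOUR_API_KEY")  # placeholder credential
        # Same call shape as the sync client, minus the removed mode/documents keys.
        res = await aio_co.chat("What are the tallest penguins?")
        print(res.text)
        await aio_co.close()

    asyncio.run(main())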
32 changes: 8 additions & 24 deletions cohere/responses/chat.py
@@ -1,35 +1,25 @@
import json
from enum import Enum
from typing import Any, Dict, Generator, List, NamedTuple, Optional

import requests

from cohere.responses.base import CohereObject


class Mode(str, Enum):
CHAT = "chat"
SEARCH_QUERY_GENERATION = "search_query_generation"
AUGMENTED_GENERATION = "augmented_generation"


class Chat(CohereObject):
def __init__(
self,
response_id: Optional[str],
generation_id: Optional[str],
message: Optional[str],
text: Optional[str],
conversation_id: Optional[str],
response_id: str,
generation_id: str,
message: str,
text: str,
conversation_id: str,
meta: Optional[Dict[str, Any]] = None,
prompt: Optional[str] = None,
chatlog: Optional[List[Dict[str, str]]] = None,
preamble: Optional[str] = None,
token_count: Optional[Dict[str, int]] = None,
client=None,
is_search_required: Optional[bool] = None,
queries: Optional[List[str]] = None,
citations: Optional[List[Dict[str, str]]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -45,16 +35,13 @@ def __init__(
self.client = client
self.token_count = token_count
self.meta = meta
self.queries = queries
self.citations = citations
self.is_search_required = is_search_required

@classmethod
def from_dict(cls, response: Dict[str, Any], message: str, client) -> "Chat":
return cls(
id=response.get("response_id"),
response_id=response.get("response_id"),
generation_id=response.get("generation_id"),
id=response["response_id"],
response_id=response["response_id"],
generation_id=response["generation_id"],
message=message,
conversation_id=response["conversation_id"],
text=response.get("text"),
@@ -64,9 +51,6 @@ def from_dict(cls, response: Dict[str, Any], message: str, client) -> "Chat":
client=client,
token_count=response.get("token_count"),
meta=response.get("meta"),
queries=response.get("queries"),
is_search_required=response.get("is_search_required"),
citations=response.get("citations"),
)

def respond(self, response: str, max_tokens: int = None) -> "Chat":
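With the Optional defaults gone, Chat.from_dict indexes response_id and generation_id directly (conversation_id already was), so a payload missing any of these keys now raises KeyError instead of producing None-valued attributes. An illustrative sketch with a fabricated payload, assuming the CohereObject base class accepts the forwarded id keyword:

    from cohere.responses.chat import Chat

    # Fabricated payload containing only the keys that are now required,
    # plus the optional "text" field read via .get().
    payload = {
        "response_id": "example-response-id",
        "generation_id": "example-generation-id",
        "conversation_id": "example-conversation-id",
        "text": "Emperor penguins are the tallest.",
    }

    chat = Chat.from_dict(payload, message="What are the tallest penguins?", client=None)
    print(chat.text)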
37 changes: 0 additions & 37 deletions tests/async/test_async_chat.py
@@ -1,9 +1,5 @@
from typing import List

import pytest

from cohere.responses.chat import Mode


@pytest.mark.asyncio
async def test_async_multi_replies(async_client):
@@ -20,39 +16,6 @@ async def test_async_multi_replies(async_client):
assert prediction.meta["api_version"]["version"]


@pytest.mark.asyncio
async def test_search_query_generation(async_client):
prediction = await async_client.chat("What are the tallest penguins?", mode="search_query_generation")
assert isinstance(prediction.is_search_required, bool)
assert isinstance(prediction.queries, List)
assert prediction.is_search_required
assert len(prediction.queries) > 0


@pytest.mark.asyncio
async def test_search_query_generation_with_enum(async_client):
prediction = await async_client.chat("What are the tallest penguins?", mode=Mode.SEARCH_QUERY_GENERATION)
assert isinstance(prediction.is_search_required, bool)
assert isinstance(prediction.queries, List)
assert prediction.is_search_required
assert len(prediction.queries) > 0


@pytest.mark.asyncio
async def test_augmented_generation(async_client):
prediction = await async_client.chat(
"What are the tallest penguins?",
mode="augmented_generation",
documents=[
{"title": "Tall penguins", "snippet": "Emperor penguins are the tallest", "url": "http://example.com/foo"}
],
)
assert isinstance(prediction.text, str)
assert isinstance(prediction.citations, List)
assert len(prediction.text) > 0
assert len(prediction.citations) > 0


@pytest.mark.asyncio
async def test_async_chat_stream(async_client):
res = await async_client.chat(
35 changes: 0 additions & 35 deletions tests/sync/test_chat.py
@@ -211,38 +211,3 @@ def test_invalid_logit_bias(self):
for logit_bias in invalid:
with self.assertRaises(cohere.error.CohereError):
_ = co.chat("Yo what up?", logit_bias=logit_bias, max_tokens=5)


"""
Stop testing augmented generation while we change the API
def test_search_query_generation(self):
prediction = co.chat("What are the tallest penguins?", mode="search_query_generation")
self.assertIsInstance(prediction.is_search_required, bool)
self.assertIsInstance(prediction.queries, List)
self.assertTrue(prediction.is_search_required)
self.assertGreater(len(prediction.queries), 0)
def test_search_query_generation_with_enum(self):
prediction = co.chat("What are the tallest penguins?", mode=Mode.SEARCH_QUERY_GENERATION)
self.assertIsInstance(prediction.is_search_required, bool)
self.assertIsInstance(prediction.queries, List)
self.assertTrue(prediction.is_search_required)
self.assertGreater(len(prediction.queries), 0)
def test_augmented_generation(self):
prediction = co.chat(
"What are the tallest penguins?",
mode="augmented_generation",
documents=[
{
"title": "Tall penguins",
"snippet": "Emperor penguins are the tallest",
"url": "http://example.com/foo",
}
],
)
self.assertIsInstance(prediction.text, str)
self.assertIsInstance(prediction.citations, List)
self.assertGreater(len(prediction.text), 0)
self.assertGreater(len(prediction.citations), 0)
"""
