From dd7d7fa87ddba826fc9b81d12028679a98babcd7 Mon Sep 17 00:00:00 2001
From: Aber
Date: Tue, 22 Oct 2024 12:09:41 +0800
Subject: [PATCH] Add WebSocket API (#2)

---
 pdm.lock                         |  33 ++++++-
 pyproject.toml                   |   1 +
 src/fish_audio_sdk/__init__.py   |  11 ++-
 src/fish_audio_sdk/exceptions.py |   6 ++
 src/fish_audio_sdk/schemas.py    |  14 +++
 src/fish_audio_sdk/websocket.py  | 154 +++++++++++++++++++++++++++++++
 tests/conftest.py                |  12 ++-
 tests/test_websocket.py          |  31 +++++++
 8 files changed, 259 insertions(+), 3 deletions(-)
 create mode 100644 src/fish_audio_sdk/websocket.py
 create mode 100644 tests/test_websocket.py

diff --git a/pdm.lock b/pdm.lock
index ee9549d..c6c1aa3 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -5,7 +5,7 @@
 groups = ["default", "dev"]
 strategy = ["cross_platform", "inherit_metadata"]
 lock_version = "4.4.1"
-content_hash = "sha256:e56a97f3923d30a67d34db2f14b029a527adf4f0ab7e1d9f87f9027bb3f274b3"
+content_hash = "sha256:983e900f536fe59fcb8b10175691ed452091c70e0fdc09c284c8c336274583d0"
 
 [[package]]
 name = "annotated-types"
@@ -114,6 +114,23 @@ files = [
     {file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"},
 ]
 
+[[package]]
+name = "httpx-ws"
+version = "0.6.2"
+requires_python = ">=3.8"
+summary = "WebSockets support for HTTPX"
+groups = ["default"]
+dependencies = [
+    "anyio>=4",
+    "httpcore>=1.0.4",
+    "httpx>=0.23.1",
+    "wsproto",
+]
+files = [
+    {file = "httpx_ws-0.6.2-py3-none-any.whl", hash = "sha256:24f87427acb757ada200aeab016cc429fa0bc71b0730429c37634867194e305c"},
+    {file = "httpx_ws-0.6.2.tar.gz", hash = "sha256:b07446b9067a30f1012fa9851fdfd14207012cd657c485565884f90553d0854c"},
+]
+
 [[package]]
 name = "idna"
 version = "3.9"
@@ -354,3 +371,17 @@ files = [
     {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"},
     {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"},
 ]
+
+[[package]]
+name = "wsproto"
+version = "1.2.0"
+requires_python = ">=3.7.0"
+summary = "WebSockets state-machine based protocol implementation"
+groups = ["default"]
+dependencies = [
+    "h11<1,>=0.9.0",
+]
+files = [
+    {file = "wsproto-1.2.0-py3-none-any.whl", hash = "sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736"},
+    {file = "wsproto-1.2.0.tar.gz", hash = "sha256:ad565f26ecb92588a3e43bc3d96164de84cd9902482b130d0ddbaa9664a85065"},
+]
diff --git a/pyproject.toml b/pyproject.toml
index 928c499..9ae40b4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,6 +9,7 @@ dependencies = [
     "httpx>=0.27.2",
     "ormsgpack>=1.5.0",
     "pydantic>=2.9.1",
+    "httpx-ws>=0.6.2",
 ]
 requires-python = ">=3.10"
 readme = "README.md"
diff --git a/src/fish_audio_sdk/__init__.py b/src/fish_audio_sdk/__init__.py
index 081fca4..3d03cf6 100644
--- a/src/fish_audio_sdk/__init__.py
+++ b/src/fish_audio_sdk/__init__.py
@@ -1,5 +1,14 @@
 from .apis import Session
 from .exceptions import HttpCodeErr
 from .schemas import ASRRequest, TTSRequest, ReferenceAudio
+from .websocket import WebSocketSession, AsyncWebSocketSession
 
-__all__ = ["Session", "HttpCodeErr", "ReferenceAudio", "TTSRequest", "ASRRequest"]
+__all__ = [
+    "Session",
+    "HttpCodeErr",
+    "ReferenceAudio",
+    "TTSRequest",
+    "ASRRequest",
+    "WebSocketSession",
+    "AsyncWebSocketSession",
+]
diff --git a/src/fish_audio_sdk/exceptions.py b/src/fish_audio_sdk/exceptions.py
index 1dda6c4..fdbcded 100644
--- a/src/fish_audio_sdk/exceptions.py
+++ b/src/fish_audio_sdk/exceptions.py
@@ -10,3 +10,9 @@ def __init__(self, status: int, message: str):
         self.status = status
         self.message = message
         super().__init__(f"{status} {message}")
+
+
+class WebSocketErr(Exception):
+    """
+    {"event": "finish", "reason": "error"} or WebSocketDisconnect
+    """
diff --git a/src/fish_audio_sdk/schemas.py b/src/fish_audio_sdk/schemas.py
index c40f60d..da85ea8 100644
--- a/src/fish_audio_sdk/schemas.py
+++ b/src/fish_audio_sdk/schemas.py
@@ -104,3 +104,17 @@ class PackageEntity(BaseModel):
     created_at: str
     updated_at: str
     finished_at: str
+
+
+class StartEvent(BaseModel):
+    event: Literal["start"] = "start"
+    request: TTSRequest
+
+
+class TextEvent(BaseModel):
+    event: Literal["text"] = "text"
+    text: str
+
+
+class CloseEvent(BaseModel):
+    event: Literal["stop"] = "stop"
diff --git a/src/fish_audio_sdk/websocket.py b/src/fish_audio_sdk/websocket.py
new file mode 100644
index 0000000..90b5b14
--- /dev/null
+++ b/src/fish_audio_sdk/websocket.py
@@ -0,0 +1,154 @@
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+from typing import AsyncGenerator, AsyncIterable, Generator, Iterable
+
+import httpx
+import ormsgpack
+from httpx_ws import WebSocketDisconnect, connect_ws, aconnect_ws
+
+from .exceptions import WebSocketErr
+
+from .schemas import CloseEvent, StartEvent, TTSRequest, TextEvent
+
+
+def _pack(event: StartEvent | TextEvent | CloseEvent) -> bytes:
+    """Serialize a protocol event model to msgpack bytes."""
+    return ormsgpack.packb(event, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
+
+
+class WebSocketSession:
+    """Blocking client for the ``/v1/tts/live`` WebSocket endpoint."""
+
+    def __init__(
+        self,
+        apikey: str,
+        *,
+        base_url: str = "https://api.fish.audio",
+        max_workers: int = 10,
+    ):
+        self._apikey = apikey
+        self._base_url = base_url
+        # Each tts() call runs its sender on one worker thread.
+        self._executor = ThreadPoolExecutor(max_workers=max_workers)
+        self._client = httpx.Client(
+            base_url=self._base_url,
+            headers={"Authorization": f"Bearer {self._apikey}"},
+        )
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+
+    def close(self):
+        """Stop the sender thread pool and close the HTTP client."""
+        # Do not block on senders that may still be draining a text stream.
+        self._executor.shutdown(wait=False)
+        self._client.close()
+
+    def tts(
+        self, request: TTSRequest, text_stream: Iterable[str]
+    ) -> Generator[bytes, None, None]:
+        """Stream ``text_stream`` to the server and yield audio chunks.
+
+        Text is sent from a worker thread while this generator receives.
+
+        Raises:
+            WebSocketErr: the server finished with reason "error" or the
+                socket disconnected.
+        """
+        with connect_ws("/v1/tts/live", client=self._client) as ws:
+
+            def sender():
+                ws.send_bytes(_pack(StartEvent(request=request)))
+                for text in text_stream:
+                    ws.send_bytes(_pack(TextEvent(text=text)))
+                ws.send_bytes(_pack(CloseEvent()))
+
+            sender_future = self._executor.submit(sender)
+
+            while True:
+                # Only the receive call can raise WebSocketDisconnect.
+                try:
+                    message = ws.receive_bytes()
+                except WebSocketDisconnect as exc:
+                    raise WebSocketErr("websocket disconnected") from exc
+                data = ormsgpack.unpackb(message)
+                match data["event"]:
+                    case "audio":
+                        yield data["audio"]
+                    case "finish" if data["reason"] == "error":
+                        raise WebSocketErr("server finished with an error")
+                    case "finish" if data["reason"] == "stop":
+                        break
+
+            # Surface any exception raised while sending.
+            sender_future.result()
+
+
+class AsyncWebSocketSession:
+    """Asyncio client for the ``/v1/tts/live`` WebSocket endpoint."""
+
+    def __init__(
+        self,
+        apikey: str,
+        *,
+        base_url: str = "https://api.fish.audio",
+    ):
+        self._apikey = apikey
+        self._base_url = base_url
+        self._client = httpx.AsyncClient(
+            base_url=self._base_url,
+            headers={"Authorization": f"Bearer {self._apikey}"},
+        )
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        await self.close()
+
+    async def close(self):
+        """Close the underlying HTTP client."""
+        await self._client.aclose()
+
+    async def tts(
+        self, request: TTSRequest, text_stream: AsyncIterable[str]
+    ) -> AsyncGenerator[bytes, None]:
+        """Stream ``text_stream`` to the server and yield audio chunks.
+
+        Text is sent from a background task while this generator receives.
+
+        Raises:
+            WebSocketErr: the server finished with reason "error" or the
+                socket disconnected.
+        """
+        async with aconnect_ws("/v1/tts/live", client=self._client) as ws:
+
+            async def sender():
+                await ws.send_bytes(_pack(StartEvent(request=request)))
+                async for text in text_stream:
+                    await ws.send_bytes(_pack(TextEvent(text=text)))
+                await ws.send_bytes(_pack(CloseEvent()))
+
+            sender_task = asyncio.create_task(sender())
+            try:
+                while True:
+                    try:
+                        message = await ws.receive_bytes()
+                    except WebSocketDisconnect as exc:
+                        raise WebSocketErr("websocket disconnected") from exc
+                    data = ormsgpack.unpackb(message)
+                    match data["event"]:
+                        case "audio":
+                            yield data["audio"]
+                        case "finish" if data["reason"] == "error":
+                            raise WebSocketErr("server finished with an error")
+                        case "finish" if data["reason"] == "stop":
+                            break
+
+                await sender_task
+            finally:
+                # Stop the sender on error or early close; no-op once done.
+                sender_task.cancel()
diff --git a/tests/conftest.py b/tests/conftest.py
index d77fa70..17c2108 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,7 +2,7 @@
 
 import pytest
 
-from fish_audio_sdk.apis import Session
+from fish_audio_sdk import Session, WebSocketSession, AsyncWebSocketSession
 
 APIKEY = os.environ["APIKEY"]
 
@@ -10,3 +10,13 @@
 @pytest.fixture
 def session():
     return Session(APIKEY)
+
+
+@pytest.fixture
+def sync_websocket():
+    return WebSocketSession(APIKEY)
+
+
+@pytest.fixture
+def async_websocket():
+    return AsyncWebSocketSession(APIKEY)
diff --git a/tests/test_websocket.py b/tests/test_websocket.py
new file mode 100644
index 0000000..bd41ba0
--- /dev/null
+++ b/tests/test_websocket.py
@@ -0,0 +1,31 @@
+from fish_audio_sdk import TTSRequest, WebSocketSession, AsyncWebSocketSession
+
+story = """
+修炼了六千三百七十九年又三月零六天后,天门因她终于洞开。
+
+她凭虚站立在黄山峰顶,因天门洞开而鼓起的飓风不停拍打着她身上的黑袍,在催促她快快登仙而去;黄山间壮阔的云海也随之翻涌,为这一场天地幸事欢呼雀跃。她没有抬头看向那似隐似现、若有若无、形态万千变化的天门,只是呆立在原处自顾自地看向远方。
+"""
+
+
+def test_tts(sync_websocket: WebSocketSession):
+    audio = bytearray()
+
+    def stream():
+        # One text event per line of the story (blank lines included).
+        yield from story.split("\n")
+
+    for chunk in sync_websocket.tts(TTSRequest(text=""), stream()):
+        audio.extend(chunk)
+    assert len(audio) > 0
+
+
+async def test_async_tts(async_websocket: AsyncWebSocketSession):
+    audio = bytearray()
+
+    async def stream():
+        for line in story.split("\n"):
+            yield line
+
+    async for chunk in async_websocket.tts(TTSRequest(text=""), stream()):
+        audio.extend(chunk)
+    assert len(audio) > 0