Skip to content

Commit

Permalink
Add WebSocket API (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
abersheeran authored Oct 22, 2024
1 parent 2d8c04b commit dd7d7fa
Show file tree
Hide file tree
Showing 8 changed files with 252 additions and 3 deletions.
33 changes: 32 additions & 1 deletion pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ dependencies = [
"httpx>=0.27.2",
"ormsgpack>=1.5.0",
"pydantic>=2.9.1",
"httpx-ws>=0.6.2",
]
requires-python = ">=3.10"
readme = "README.md"
Expand Down
11 changes: 10 additions & 1 deletion src/fish_audio_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
from .apis import Session
from .exceptions import HttpCodeErr
from .schemas import ASRRequest, TTSRequest, ReferenceAudio
from .websocket import WebSocketSession, AsyncWebSocketSession

__all__ = ["Session", "HttpCodeErr", "ReferenceAudio", "TTSRequest", "ASRRequest"]
# Names re-exported as the package's public API, including the two
# WebSocket session classes imported from ``.websocket`` above.
__all__ = [
    "Session",
    "HttpCodeErr",
    "ReferenceAudio",
    "TTSRequest",
    "ASRRequest",
    "WebSocketSession",
    "AsyncWebSocketSession",
]
6 changes: 6 additions & 0 deletions src/fish_audio_sdk/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,9 @@ def __init__(self, status: int, message: str):
self.status = status
self.message = message
super().__init__(f"{status} {message}")


class WebSocketErr(Exception):
    """
    Raised when a WebSocket TTS session ends abnormally.

    This happens in two situations: the server sent
    ``{"event": "finish", "reason": "error"}``, or the connection was
    dropped and the transport raised ``WebSocketDisconnect``.
    """
14 changes: 14 additions & 0 deletions src/fish_audio_sdk/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,17 @@ class PackageEntity(BaseModel):
created_at: str
updated_at: str
finished_at: str


class StartEvent(BaseModel):
    """Message that opens a streaming TTS session.

    The WebSocket sessions serialize it with msgpack and send it before any
    text is sent.
    """

    # Event tag; always the literal "start".
    event: Literal["start"] = "start"
    # The TTS request passed by the caller to ``tts``.
    request: TTSRequest


class TextEvent(BaseModel):
    """Message carrying one piece of text from the caller's text stream."""

    # Event tag; always the literal "text".
    event: Literal["text"] = "text"
    # One item taken from the text stream.
    text: str


class CloseEvent(BaseModel):
    """Message sent once the caller's text stream is exhausted.

    NOTE(review): the server presumably answers with a
    ``{"event": "finish", "reason": "stop"}`` message -- confirm against the
    API documentation.
    """

    # Event tag; note the wire value is "stop", not "close".
    event: Literal["stop"] = "stop"
147 changes: 147 additions & 0 deletions src/fish_audio_sdk/websocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import AsyncGenerator, AsyncIterable, Generator, Iterable

import httpx
import ormsgpack
from httpx_ws import WebSocketDisconnect, connect_ws, aconnect_ws

from .exceptions import WebSocketErr

from .schemas import CloseEvent, StartEvent, TTSRequest, TextEvent


class WebSocketSession:
def __init__(
self,
apikey: str,
*,
base_url: str = "https://api.fish.audio",
max_workers: int = 10,
):
self._apikey = apikey
self._base_url = base_url
self._executor = ThreadPoolExecutor(max_workers=max_workers)
self._client = httpx.Client(
base_url=self._base_url,
headers={"Authorization": f"Bearer {self._apikey}"},
)

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
self.close()

def close(self):
self._client.close()

def tts(
self, request: TTSRequest, text_stream: Iterable[str]
) -> Generator[bytes, None, None]:
with connect_ws("/v1/tts/live", client=self._client) as ws:

def sender():
ws.send_bytes(
ormsgpack.packb(
StartEvent(request=request),
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
)
)
for text in text_stream:
ws.send_bytes(
ormsgpack.packb(
TextEvent(text=text),
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
)
)
ws.send_bytes(
ormsgpack.packb(
CloseEvent(), option=ormsgpack.OPT_SERIALIZE_PYDANTIC
)
)

sender_future = self._executor.submit(sender)

while True:
try:
message = ws.receive_bytes()
data = ormsgpack.unpackb(message)
match data["event"]:
case "audio":
yield data["audio"]
case "finish" if data["reason"] == "error":
raise WebSocketErr
case "finish" if data["reason"] == "stop":
break
except WebSocketDisconnect:
raise WebSocketErr

sender_future.result()


class AsyncWebSocketSession:
def __init__(
self,
apikey: str,
*,
base_url: str = "https://api.fish.audio",
):
self._apikey = apikey
self._base_url = base_url
self._client = httpx.AsyncClient(
base_url=self._base_url,
headers={"Authorization": f"Bearer {self._apikey}"},
)

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_value, traceback):
await self.close()

async def close(self):
await self._client.aclose()

async def tts(
self, request: TTSRequest, text_stream: AsyncIterable[str]
) -> AsyncGenerator[bytes, None]:
async with aconnect_ws("/v1/tts/live", client=self._client) as ws:

async def sender():
await ws.send_bytes(
ormsgpack.packb(
StartEvent(request=request),
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
)
)
async for text in text_stream:
await ws.send_bytes(
ormsgpack.packb(
TextEvent(text=text),
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
)
)
await ws.send_bytes(
ormsgpack.packb(
CloseEvent(), option=ormsgpack.OPT_SERIALIZE_PYDANTIC
)
)

sender_future = asyncio.get_running_loop().create_task(sender())

while True:
try:
message = await ws.receive_bytes()
data = ormsgpack.unpackb(message)
match data["event"]:
case "audio":
yield data["audio"]
case "finish" if data["reason"] == "error":
raise WebSocketErr
case "finish" if data["reason"] == "stop":
break
except WebSocketDisconnect:
raise WebSocketErr

await sender_future
12 changes: 11 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,21 @@

import pytest

from fish_audio_sdk.apis import Session
from fish_audio_sdk import Session, WebSocketSession, AsyncWebSocketSession

APIKEY = os.environ["APIKEY"]


@pytest.fixture
def session():
    """Provide a ``Session`` built from the ``APIKEY`` environment value."""
    api_session = Session(APIKEY)
    return api_session


@pytest.fixture
def sync_websocket():
    """Yield a ``WebSocketSession`` and close it when the test finishes."""
    # Entering the session as a context manager makes teardown call its
    # close(), so the HTTP client is not left open after the test.
    with WebSocketSession(APIKEY) as ws_session:
        yield ws_session


@pytest.fixture
def async_websocket():
    """Provide an ``AsyncWebSocketSession`` built from ``APIKEY``."""
    ws_session = AsyncWebSocketSession(APIKEY)
    return ws_session
31 changes: 31 additions & 0 deletions tests/test_websocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from fish_audio_sdk import TTSRequest, WebSocketSession, AsyncWebSocketSession

story = """
修炼了六千三百七十九年又三月零六天后,天门因她终于洞开。
她凭虚站立在黄山峰顶,因天门洞开而鼓起的飓风不停拍打着她身上的黑袍,在催促她快快登仙而去;黄山间壮阔的云海也随之翻涌,为这一场天地幸事欢呼雀跃。她没有抬头看向那似隐似现、若有若无、形态万千变化的天门,只是呆立在原处自顾自地看向远方。
"""


def test_tts(sync_websocket: WebSocketSession):
    """Stream the story line by line and expect some audio back."""

    def stream():
        yield from story.split("\n")

    audio = b"".join(sync_websocket.tts(TTSRequest(text=""), stream()))
    assert len(audio) > 0


async def test_async_tts(async_websocket: AsyncWebSocketSession):
    """Stream the story line by line (async) and expect some audio back."""

    async def stream():
        for line in story.split("\n"):
            yield line

    chunks = [
        chunk
        async for chunk in async_websocket.tts(TTSRequest(text=""), stream())
    ]
    audio = b"".join(chunks)
    assert len(audio) > 0

0 comments on commit dd7d7fa

Please sign in to comment.