-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
2d8c04b
commit dd7d7fa
Showing
8 changed files
with
252 additions
and
3 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,14 @@ | ||
from .apis import Session | ||
from .exceptions import HttpCodeErr | ||
from .schemas import ASRRequest, TTSRequest, ReferenceAudio | ||
from .websocket import WebSocketSession, AsyncWebSocketSession | ||
|
||
__all__ = ["Session", "HttpCodeErr", "ReferenceAudio", "TTSRequest", "ASRRequest"] | ||
# Public names of the package: re-exports of what is imported above from
# .apis, .exceptions, .schemas and .websocket.
__all__ = [ | ||
"Session", | ||
"HttpCodeErr", | ||
"ReferenceAudio", | ||
"TTSRequest", | ||
"ASRRequest", | ||
"WebSocketSession", | ||
"AsyncWebSocketSession", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
import asyncio | ||
from concurrent.futures import ThreadPoolExecutor | ||
from typing import AsyncGenerator, AsyncIterable, Generator, Iterable | ||
|
||
import httpx | ||
import ormsgpack | ||
from httpx_ws import WebSocketDisconnect, connect_ws, aconnect_ws | ||
|
||
from .exceptions import WebSocketErr | ||
|
||
from .schemas import CloseEvent, StartEvent, TTSRequest, TextEvent | ||
|
||
|
||
class WebSocketSession: | ||
def __init__( | ||
self, | ||
apikey: str, | ||
*, | ||
base_url: str = "https://api.fish.audio", | ||
max_workers: int = 10, | ||
): | ||
self._apikey = apikey | ||
self._base_url = base_url | ||
self._executor = ThreadPoolExecutor(max_workers=max_workers) | ||
self._client = httpx.Client( | ||
base_url=self._base_url, | ||
headers={"Authorization": f"Bearer {self._apikey}"}, | ||
) | ||
|
||
def __enter__(self): | ||
return self | ||
|
||
def __exit__(self, exc_type, exc_value, traceback): | ||
self.close() | ||
|
||
def close(self): | ||
self._client.close() | ||
|
||
def tts( | ||
self, request: TTSRequest, text_stream: Iterable[str] | ||
) -> Generator[bytes, None, None]: | ||
with connect_ws("/v1/tts/live", client=self._client) as ws: | ||
|
||
def sender(): | ||
ws.send_bytes( | ||
ormsgpack.packb( | ||
StartEvent(request=request), | ||
option=ormsgpack.OPT_SERIALIZE_PYDANTIC, | ||
) | ||
) | ||
for text in text_stream: | ||
ws.send_bytes( | ||
ormsgpack.packb( | ||
TextEvent(text=text), | ||
option=ormsgpack.OPT_SERIALIZE_PYDANTIC, | ||
) | ||
) | ||
ws.send_bytes( | ||
ormsgpack.packb( | ||
CloseEvent(), option=ormsgpack.OPT_SERIALIZE_PYDANTIC | ||
) | ||
) | ||
|
||
sender_future = self._executor.submit(sender) | ||
|
||
while True: | ||
try: | ||
message = ws.receive_bytes() | ||
data = ormsgpack.unpackb(message) | ||
match data["event"]: | ||
case "audio": | ||
yield data["audio"] | ||
case "finish" if data["reason"] == "error": | ||
raise WebSocketErr | ||
case "finish" if data["reason"] == "stop": | ||
break | ||
except WebSocketDisconnect: | ||
raise WebSocketErr | ||
|
||
sender_future.result() | ||
|
||
|
||
class AsyncWebSocketSession: | ||
def __init__( | ||
self, | ||
apikey: str, | ||
*, | ||
base_url: str = "https://api.fish.audio", | ||
): | ||
self._apikey = apikey | ||
self._base_url = base_url | ||
self._client = httpx.AsyncClient( | ||
base_url=self._base_url, | ||
headers={"Authorization": f"Bearer {self._apikey}"}, | ||
) | ||
|
||
async def __aenter__(self): | ||
return self | ||
|
||
async def __aexit__(self, exc_type, exc_value, traceback): | ||
await self.close() | ||
|
||
async def close(self): | ||
await self._client.aclose() | ||
|
||
async def tts( | ||
self, request: TTSRequest, text_stream: AsyncIterable[str] | ||
) -> AsyncGenerator[bytes, None]: | ||
async with aconnect_ws("/v1/tts/live", client=self._client) as ws: | ||
|
||
async def sender(): | ||
await ws.send_bytes( | ||
ormsgpack.packb( | ||
StartEvent(request=request), | ||
option=ormsgpack.OPT_SERIALIZE_PYDANTIC, | ||
) | ||
) | ||
async for text in text_stream: | ||
await ws.send_bytes( | ||
ormsgpack.packb( | ||
TextEvent(text=text), | ||
option=ormsgpack.OPT_SERIALIZE_PYDANTIC, | ||
) | ||
) | ||
await ws.send_bytes( | ||
ormsgpack.packb( | ||
CloseEvent(), option=ormsgpack.OPT_SERIALIZE_PYDANTIC | ||
) | ||
) | ||
|
||
sender_future = asyncio.get_running_loop().create_task(sender()) | ||
|
||
while True: | ||
try: | ||
message = await ws.receive_bytes() | ||
data = ormsgpack.unpackb(message) | ||
match data["event"]: | ||
case "audio": | ||
yield data["audio"] | ||
case "finish" if data["reason"] == "error": | ||
raise WebSocketErr | ||
case "finish" if data["reason"] == "stop": | ||
break | ||
except WebSocketDisconnect: | ||
raise WebSocketErr | ||
|
||
await sender_future |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from fish_audio_sdk import TTSRequest, WebSocketSession, AsyncWebSocketSession | ||
|
||
# Sample multi-line text; the tests below split it on "\n" and stream the
# pieces one by one as TTS input.
story = """ | ||
修炼了六千三百七十九年又三月零六天后,天门因她终于洞开。 | ||
她凭虚站立在黄山峰顶,因天门洞开而鼓起的飓风不停拍打着她身上的黑袍,在催促她快快登仙而去;黄山间壮阔的云海也随之翻涌,为这一场天地幸事欢呼雀跃。她没有抬头看向那似隐似现、若有若无、形态万千变化的天门,只是呆立在原处自顾自地看向远方。 | ||
""" | ||
|
||
|
||
def test_tts(sync_websocket: WebSocketSession):
    """Stream the story line by line through the sync session; expect audio."""
    pieces = iter(story.split("\n"))
    audio = bytearray()

    for part in sync_websocket.tts(TTSRequest(text=""), pieces):
        audio.extend(part)

    assert len(audio) > 0
|
||
|
||
async def test_async_tts(async_websocket: AsyncWebSocketSession):
    """Stream the story line by line through the async session; expect audio."""

    async def line_source():
        # The async session iterates its input with ``async for``, so the
        # lines have to come from an async generator.
        for piece in story.split("\n"):
            yield piece

    audio = bytearray()
    async for part in async_websocket.tts(TTSRequest(text=""), line_source()):
        audio.extend(part)

    assert len(audio) > 0