-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add text-to-audio pipeline and dependencies
- Loading branch information
Showing
10 changed files
with
275 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import uuid | ||
from app.pipelines.base import Pipeline | ||
from app.pipelines.util import get_torch_device, get_model_dir | ||
from transformers import FastSpeech2ConformerTokenizer, FastSpeech2ConformerModel, FastSpeech2ConformerHifiGan | ||
import soundfile as sf | ||
import os | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
class TextToAudioPipeline(Pipeline): | ||
def __init__(self): | ||
self.device = get_torch_device() | ||
# preload FastSpeech 2 & hifigan | ||
self.TTS_tokenizer = FastSpeech2ConformerTokenizer.from_pretrained("espnet/fastspeech2_conformer", cache_dir=get_model_dir()) | ||
self.TTS_model = FastSpeech2ConformerModel.from_pretrained("espnet/fastspeech2_conformer", cache_dir=get_model_dir()).to(self.device) | ||
self.TTS_hifigan = FastSpeech2ConformerHifiGan.from_pretrained("espnet/fastspeech2_conformer_hifigan", cache_dir=get_model_dir()).to(self.device) | ||
|
||
|
||
def __call__(self, text): | ||
# generate unique filename | ||
unique_audio_filename = f"{uuid.uuid4()}.wav" | ||
audio_path = os.path.join("/tmp/", unique_audio_filename) | ||
|
||
self.generate_audio(text, audio_path) | ||
|
||
return audio_path | ||
|
||
def generate_audio(self, text, output_file_name): | ||
# Tokenize input text | ||
inputs = self.TTS_tokenizer(text, return_tensors="pt").to(self.device) | ||
|
||
# Ensure input IDs remain in Long tensor type | ||
input_ids = inputs["input_ids"].to(self.device) | ||
|
||
# Generate spectrogram | ||
output_dict = self.TTS_model(input_ids, return_dict=True) | ||
spectrogram = output_dict["spectrogram"] | ||
|
||
# Convert spectrogram to waveform | ||
waveform = self.TTS_hifigan(spectrogram) | ||
|
||
sf.write(output_file_name, waveform.squeeze().detach().cpu().numpy(), samplerate=22050) | ||
return output_file_name | ||
|
||
def __str__(self) -> str: | ||
return f"TextToAudioPipeline model_id={self.model_id}" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from typing import Optional, Union, List | ||
from fastapi import Depends, APIRouter, UploadFile, File, Form, HTTPException | ||
from fastapi.responses import FileResponse, JSONResponse | ||
from pydantic import BaseModel | ||
from app.pipelines.base import Pipeline | ||
from app.dependencies import get_pipeline | ||
import logging | ||
import random | ||
import json | ||
import os | ||
|
||
class HTTPError(BaseModel): | ||
detail: str | ||
|
||
router = APIRouter() | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
responses = { | ||
400: {"content": {"application/json": {"schema": HTTPError.schema()}}}, | ||
500: {"content": {"application/json": {"schema": HTTPError.schema()}}}, | ||
200: { | ||
"content": { | ||
"audio/mp4": {}, | ||
} | ||
} | ||
} | ||
|
||
@router.post("/text-to-audio", responses=responses) | ||
async def TextToAudio( | ||
text_input: Optional[str] = Form(None), | ||
pipeline: Pipeline = Depends(get_pipeline), | ||
): | ||
|
||
try: | ||
result = pipeline( | ||
text_input | ||
) | ||
|
||
except Exception as e: | ||
logger.error(f"TextToAudioPipeline error: {e}") | ||
return JSONResponse( | ||
status_code=500, | ||
content={ | ||
"detail": f"Internal Server Error: {str(e)}" | ||
}, | ||
) | ||
|
||
if os.path.exists(result): | ||
return FileResponse(path=result, media_type='audio/mp4', filename="generated_audio.mp4") | ||
else: | ||
return JSONResponse( | ||
status_code=400, | ||
content={ | ||
"detail": f"no output found for {result}" | ||
}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,3 +17,4 @@ numpy==1.26.4 | |
av==12.1.0 | ||
sentencepiece== 0.2.0 | ||
protobuf==5.27.2 | ||
soundfile |
Oops, something went wrong.