Skip to content

Commit

Permalink
add text-to-audio pipeline and dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
pschroedl committed Jul 24, 2024
1 parent 8c03423 commit aeb257f
Show file tree
Hide file tree
Showing 10 changed files with 275 additions and 29 deletions.
9 changes: 8 additions & 1 deletion runner/app/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import logging
import os
from contextlib import asynccontextmanager

from app.routes import health
from fastapi import FastAPI
from fastapi.routing import APIRoute


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -52,6 +52,9 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
from app.pipelines.upscale import UpscalePipeline

return UpscalePipeline(model_id)
case "text-to-audio":
from app.pipelines.text_to_audio import TextToAudioPipeline
return TextToAudioPipeline(model_id)
case _:
raise EnvironmentError(
f"{pipeline} is not a valid pipeline for model {model_id}"
Expand Down Expand Up @@ -82,6 +85,10 @@ def load_route(pipeline: str) -> any:
from app.routes import upscale

return upscale.router
case "text-to-audio":
from app.routes import text_to_audio

return text_to_audio.router
case _:
raise EnvironmentError(f"{pipeline} is not a valid pipeline")

Expand Down
48 changes: 48 additions & 0 deletions runner/app/pipelines/text_to_audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import uuid
from app.pipelines.base import Pipeline
from app.pipelines.util import get_torch_device, get_model_dir
from transformers import FastSpeech2ConformerTokenizer, FastSpeech2ConformerModel, FastSpeech2ConformerHifiGan
import soundfile as sf
import os
import logging

logger = logging.getLogger(__name__)

class TextToAudioPipeline(Pipeline):
    # Speech synthesis pipeline: a FastSpeech2-Conformer acoustic model turns
    # text into a spectrogram and a HiFi-GAN vocoder turns that spectrogram
    # into a waveform, which is written to a WAV file under /tmp.

    def __init__(self, model_id: str = "espnet/fastspeech2_conformer"):
        """Load the tokenizer, acoustic model and vocoder onto the torch device.

        Args:
            model_id: Identifier passed in by the pipeline loader. It is stored
                so that ``__str__`` can report it. The default keeps the
                previous zero-argument construction working.
        """
        # The loader calls TextToAudioPipeline(model_id); accepting and storing
        # the value fixes the TypeError on construction and the AttributeError
        # that __str__ raised because self.model_id was never set.
        # NOTE(review): the checkpoints below are still hard-coded; confirm
        # whether model_id should select the acoustic model instead.
        self.model_id = model_id
        self.device = get_torch_device()

        cache_dir = get_model_dir()
        # preload FastSpeech 2 & hifigan
        self.TTS_tokenizer = FastSpeech2ConformerTokenizer.from_pretrained(
            "espnet/fastspeech2_conformer", cache_dir=cache_dir
        )
        self.TTS_model = FastSpeech2ConformerModel.from_pretrained(
            "espnet/fastspeech2_conformer", cache_dir=cache_dir
        ).to(self.device)
        self.TTS_hifigan = FastSpeech2ConformerHifiGan.from_pretrained(
            "espnet/fastspeech2_conformer_hifigan", cache_dir=cache_dir
        ).to(self.device)

    def __call__(self, text):
        """Synthesize ``text`` and return the path of the generated WAV file.

        The file is created under ``/tmp/`` with a random UUID name. This
        method does not delete it; cleanup is left to the caller.
        """
        # generate unique filename so concurrent requests do not collide
        unique_audio_filename = f"{uuid.uuid4()}.wav"
        audio_path = os.path.join("/tmp/", unique_audio_filename)

        self.generate_audio(text, audio_path)

        return audio_path

    def generate_audio(self, text, output_file_name):
        """Run text -> spectrogram -> waveform and write it to ``output_file_name``.

        Args:
            text: Input text handed to the FastSpeech2-Conformer tokenizer.
            output_file_name: Destination path for the audio file.

        Returns:
            ``output_file_name``, unchanged.
        """
        # Local import: this module does not import torch at file level.
        import torch

        # Tokenize input text; .to() moves the encoded tensors to the device.
        inputs = self.TTS_tokenizer(text, return_tensors="pt").to(self.device)
        input_ids = inputs["input_ids"]

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            # Generate spectrogram
            output_dict = self.TTS_model(input_ids, return_dict=True)
            spectrogram = output_dict["spectrogram"]

            # Convert spectrogram to waveform
            waveform = self.TTS_hifigan(spectrogram)

        # NOTE(review): 22050 Hz is assumed to match the vocoder checkpoint's
        # output rate - confirm against the HiFi-GAN config.
        sf.write(
            output_file_name,
            waveform.squeeze().detach().cpu().numpy(),
            samplerate=22050,
        )
        return output_file_name

    def __str__(self) -> str:
        return f"TextToAudioPipeline model_id={self.model_id}"

1 change: 1 addition & 0 deletions runner/app/pipelines/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,4 @@ def check_nsfw_images(
clip_input=safety_checker_input.pixel_values.to(self._dtype),
)
return images, has_nsfw_concept

57 changes: 57 additions & 0 deletions runner/app/routes/text_to_audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import Optional, Union, List
from fastapi import Depends, APIRouter, UploadFile, File, Form, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
from app.pipelines.base import Pipeline
from app.dependencies import get_pipeline
import logging
import random
import json
import os

# Error payload model; its JSON schema is advertised for the 400 and 500
# responses of the text-to-audio route. Documented with comments rather than a
# docstring so the generated schema stays unchanged.
class HTTPError(BaseModel):
    # Human-readable description of what went wrong.
    detail: str

# Router exposing the text-to-audio endpoint; the application selects it when
# the configured pipeline is "text-to-audio".
router = APIRouter()

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# OpenAPI response metadata for the endpoint below. The success payload is the
# WAV file produced by the pipeline, so it is advertised as audio/wav (it was
# previously mislabelled as audio/mp4). Regenerate openapi.json after changing
# this mapping.
responses = {
    400: {"content": {"application/json": {"schema": HTTPError.schema()}}},
    500: {"content": {"application/json": {"schema": HTTPError.schema()}}},
    200: {
        "content": {
            "audio/wav": {},
        }
    },
}


# POST /text-to-audio
#   text_input: form field holding the text to synthesize.
#   pipeline:   injected pipeline instance; calling it returns the path of the
#               generated audio file.
# Returns the audio file on success, 400 for missing/empty input or a missing
# output file, and 500 when the pipeline raises.
# Documentation is kept in comments (not a docstring) so the OpenAPI operation
# description is left untouched.
@router.post("/text-to-audio", responses=responses)
async def TextToAudio(
    text_input: Optional[str] = Form(None),
    pipeline: Pipeline = Depends(get_pipeline),
):
    # Local import: starlette ships with fastapi but is not imported at file level.
    from starlette.background import BackgroundTask

    # Reject missing or blank input up front; handing None to the pipeline
    # would otherwise surface as an opaque 500 from the tokenizer.
    if text_input is None or not text_input.strip():
        return JSONResponse(
            status_code=400,
            content={"detail": "text_input is required and must not be empty"},
        )

    try:
        result = pipeline(text_input)
    except Exception as e:
        logger.error(f"TextToAudioPipeline error: {e}")
        return JSONResponse(
            status_code=500,
            content={"detail": f"Internal Server Error: {str(e)}"},
        )

    if os.path.exists(result):
        # The pipeline writes a .wav file, so serve it as WAV. The background
        # task deletes the temporary file once the response has been sent,
        # otherwise every request would leave a file behind on disk.
        return FileResponse(
            path=result,
            media_type="audio/wav",
            filename="generated_audio.wav",
            background=BackgroundTask(os.remove, result),
        )

    return JSONResponse(
        status_code=400,
        content={"detail": f"no output found for {result}"},
    )
12 changes: 12 additions & 0 deletions runner/app/routes/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ class ImageResponse(BaseModel):
class VideoResponse(BaseModel):
frames: List[List[Media]]

# Response model carrying one audio item.
# NOTE(review): Media is declared elsewhere in this module; presumably the same
# media descriptor used by the image/video responses - confirm.
class AudioResponse(BaseModel):
    # The audio item returned to the client.
    audio: Media

# One piece of text together with its timestamp.
# NOTE(review): another `class chunk(BaseModel)` follows a few lines further
# down in this file; this definition looks like an accidental duplicate -
# confirm and remove one of them.
class chunk(BaseModel):
    # Timestamp for this piece of text.
    # NOTE(review): presumably a (start, end) pair - confirm.
    timestamp: tuple
    # The text of this piece.
    text: str


# Response model holding a full text plus its timestamped chunks.
# NOTE(review): check whether this model is already declared later in the file
# alongside the second `chunk` definition.
class TextResponse(BaseModel):
    # The complete text.
    text: str
    # The same content split into timestamped pieces.
    chunks: List[chunk]


class chunk(BaseModel):
timestamp: tuple
Expand Down
7 changes: 7 additions & 0 deletions runner/dl_checkpoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ function download_alpha_models() {

# Download upscale models
huggingface-cli download stabilityai/stable-diffusion-x4-upscaler --include "*.fp16.safetensors" --cache-dir models


# Download text-to-audio models (FastSpeech2-Conformer acoustic model and its HiFi-GAN vocoder).
# These must be the same repositories the text-to-audio pipeline loads with
# from_pretrained, and they go into the shared "models" cache like every other
# download here, so the runner can find them without network access.
# Both weight formats are included because it is not verified here which one
# each repository publishes.
huggingface-cli download espnet/fastspeech2_conformer --include "*.safetensors" "*.bin" "*.json" --cache-dir models
huggingface-cli download espnet/fastspeech2_conformer_hifigan --include "*.safetensors" "*.bin" "*.json" --cache-dir models

# Download audio-to-text models.
huggingface-cli download openai/whisper-large-v3 --include "*.safetensors" "*.json" --cache-dir models
Expand All @@ -39,6 +44,8 @@ function download_alpha_models() {
# Download image-to-video models (token-gated).
check_hf_auth
huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt-1-1 --include "*.fp16.safetensors" "*.json" --cache-dir models ${TOKEN_FLAG:+"$TOKEN_FLAG"}


}

# Download all models.
Expand Down
2 changes: 2 additions & 0 deletions runner/gen_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from app.main import app, use_route_names_as_operation_ids
from app.routes import (
audio_to_text,
text_to_audio,
health,
image_to_image,
image_to_video,
Expand Down Expand Up @@ -85,6 +86,7 @@ def write_openapi(fname, entrypoint="runner"):
app.include_router(image_to_video.router)
app.include_router(upscale.router)
app.include_router(audio_to_text.router)
app.include_router(text_to_audio.router)

use_route_names_as_operation_ids(app)

Expand Down
107 changes: 93 additions & 14 deletions runner/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,87 @@
}
]
}
},
"/text-to-audio": {
"post": {
"summary": "Texttoaudio",
"operationId": "TextToAudio",
"requestBody": {
"content": {
"application/x-www-form-urlencoded": {
"schema": {
"allOf": [
{
"$ref": "#/components/schemas/Body_TextToAudio_text_to_audio_post"
}
],
"title": "Body"
}
}
}
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {}
},
"audio/mp4": {}
}
},
"400": {
"description": "Bad Request",
"content": {
"application/json": {
"schema": {
"properties": {
"detail": {
"type": "string",
"title": "Detail"
}
},
"type": "object",
"required": [
"detail"
],
"title": "HTTPError"
}
}
}
},
"500": {
"description": "Internal Server Error",
"content": {
"application/json": {
"schema": {
"properties": {
"detail": {
"type": "string",
"title": "Detail"
}
},
"type": "object",
"required": [
"detail"
],
"title": "HTTPError"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
}
}
}
}
}
},
"components": {
Expand All @@ -421,24 +502,22 @@
],
"title": "APIError"
},
"Body_audio_to_text_audio_to_text_post": {
"Body_TextToAudio_text_to_audio_post": {
"properties": {
"audio": {
"type": "string",
"format": "binary",
"title": "Audio"
},
"model_id": {
"type": "string",
"title": "Model Id",
"default": ""
"text_input": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"title": "Text Input"
}
},
"type": "object",
"required": [
"audio"
],
"title": "Body_audio_to_text_audio_to_text_post"
"title": "Body_TextToAudio_text_to_audio_post"
},
"Body_image_to_image_image_to_image_post": {
"properties": {
Expand Down
1 change: 1 addition & 0 deletions runner/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ numpy==1.26.4
av==12.1.0
sentencepiece== 0.2.0
protobuf==5.27.2
soundfile
Loading

0 comments on commit aeb257f

Please sign in to comment.