Skip to content

Commit

Permalink
feature(ai): add spandrel to enable GAN upscalers
Browse files Browse the repository at this point in the history
*add selected upscalers with permissive licenses
  • Loading branch information
ad-astra-video committed Aug 29, 2024
1 parent d97a354 commit f83693e
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 11 deletions.
165 changes: 155 additions & 10 deletions runner/app/pipelines/upscale.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,112 @@
import logging
import os
from typing import List, Optional, Tuple
from enum import Enum

import PIL
import PIL.Image
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import (
SafetyChecker,
get_model_dir,
get_torch_device,
is_lightning_model,
is_turbo_model,
is_turbo_model
)
from diffusers import StableDiffusionUpscalePipeline
from huggingface_hub import file_download
from PIL import ImageFile
from spandrel import ModelLoader
import numpy as np
from huggingface_hub import file_download, hf_hub_download
from PIL import ImageFile, Image

ImageFile.LOAD_TRUNCATED_IMAGES = True

logger = logging.getLogger(__name__)

class ModelName(Enum):
"""Enumeration mapping model names to their corresponding IDs."""

BSRGAN_X2 = "bsrgan_x2"
BSRGAN_X4 = "bsrgan_x4"
NMKD_SUPERSCALE_SP_X4 = "nmkd_superscale_sp_x4"
NMKD_TYPESCALE_X8 = "nmkd_typescale_x4"
REALESRGAN_X2 = "realesrgan_x2"
REALESRGAN_X4 = "realesrgan_x4"
REALESRGAN_ANIME_X4 = "realesrgan_anime_x4"
REAL_HAT_GAN_SR_X4 = "real_hat_gan_sr_x4"
REAL_HAT_GAN_SR_X4_SHARPER = "real_hat_gan_sr_x4_sharper"
SCUNET_COLOR_GAN = "scunet_color_real_gan"
SCUNET_COLOR_PSNR = "scunet_color_real_psnr"
SWIN2SR_CLASSICAL_X2 = "swin2sr_classical_x2"
SWIN2SR_CLASSICAL_X4 = "swin2sr_classical_x4"


@classmethod
def list(cls):
"""Return a list of all model IDs."""
return list(map(lambda c: c.value, cls))
@classmethod
def get_model_file(cls, model):
match model:
case cls.BSRGAN_X2.value:
return hf_hub_download("ad-astra-video/upscalers", filename="BSRGANx2.pth", cache_dir=get_model_dir())
case cls.BSRGAN_X4.value:
return hf_hub_download("ad-astra-video/upscalers", filename="BSRGANx4.pth", cache_dir=get_model_dir())
case cls.NMKD_SUPERSCALE_SP_X4.value:
return hf_hub_download("ad-astra-video/upscalers", filename="4x_NMKD-Superscale-SP_178000_G.pth", cache_dir=get_model_dir())
case cls.NMKD_TYPESCALE_X8.value:
return hf_hub_download("ad-astra-video/upscalers", filename="8x_NMKD-Typescale_175k.pth", cache_dir=get_model_dir())
case cls.REALESRGAN_X2.value:
return hf_hub_download("ad-astra-video/upscalers", filename="RealESRGAN_x2plus.pth", cache_dir=get_model_dir())
case cls.REALESRGAN_X4.value:
return hf_hub_download("ad-astra-video/upscalers", filename="RealESRGAN_x4plus.pth", cache_dir=get_model_dir())
case cls.REALESRGAN_ANIME_X4.value:
return hf_hub_download("ad-astra-video/upscalers", filename="RealESRGAN_x4plus_anime_6B.pth", cache_dir=get_model_dir())
case cls.REAL_HAT_GAN_SR_X4.value:
return hf_hub_download("ad-astra-video/upscalers", filename="Real_HAT_GAN_SRx4.pth", cache_dir=get_model_dir())
case cls.REAL_HAT_GAN_SR_X4_SHARPER.value:
return hf_hub_download("ad-astra-video/upscalers", filename="Real_HAT_GAN_sharper.pth", cache_dir=get_model_dir())
case cls.SCUNET_COLOR_GAN.value:
return hf_hub_download("ad-astra-video/upscalers", filename="scunet_color_real_gan.pth", cache_dir=get_model_dir())
case cls.SCUNET_COLOR_PSNR.value:
return hf_hub_download("ad-astra-video/upscalers", filename="scunet_color_real_psnr.pth", cache_dir=get_model_dir())
case cls.SWIN2SR_CLASSICAL_X2.value:
return hf_hub_download("ad-astra-video/upscalers", filename="Swin2SR_ClassicalSR_X2_64.pth", cache_dir=get_model_dir())
case cls.SWIN2SR_CLASSICAL_X4.value:
return hf_hub_download("ad-astra-video/upscalers", filename="Swin2SR_ClassicalSR_X4_64.pth", cache_dir=get_model_dir())

@classmethod
def get_model_scale(cls, model):
match model:
case cls.BSRGAN_X2.value:
return 2
case cls.BSRGAN_X4.value:
return 4
case cls.NMKD_SUPERSCALE_SP_X4.value:
return 4
case cls.NMKD_TYPESCALE_X8.value:
return 8
case cls.REALESRGAN_X2.value:
return 2
case cls.REALESRGAN_X4.value:
return 4
case cls.REALESRGAN_ANIME_X4.value:
return 4
case cls.REAL_HAT_GAN_SR_X4.value:
return 4
case cls.REAL_HAT_GAN_SR_X4_SHARPER.value:
return 4
case cls.SCUNET_COLOR_GAN.value:
return 4
case cls.SCUNET_COLOR_PSNR.value:
return 4
case cls.SWIN2SR_CLASSICAL_X2.value:
return 2
case cls.SWIN2SR_CLASSICAL_X4.value:
return 4


class UpscalePipeline(Pipeline):
def __init__(self, model_id: str):
self.model_id = model_id
Expand All @@ -42,9 +128,22 @@ def __init__(self, model_id: str):
kwargs["torch_dtype"] = torch.float16
kwargs["variant"] = "fp16"

self.ldm = StableDiffusionUpscalePipeline.from_pretrained(
model_id, **kwargs
).to(torch_device)
if self.model_id == "stabilityai/stable-diffusion-x4-upscaler":
self.ldm = StableDiffusionUpscalePipeline.from_pretrained(
model_id, **kwargs
).to(torch_device)
else:
if self.model_id in ModelName.list():
#use spandrel to load the model
model_file = ModelName.get_model_file(self.model_id)
logger.info(f"loading model file: {model_file}")
model = ModelLoader().load_from_file(model_file)
logger.info(f"model loaded, scale={model.scale}")
# send it to the GPU and put it in inference mode
model.cuda().eval()
self.ldm = model
else:
raise ValueError("Model not supported")

sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
deepcache_enabled = os.getenv("DEEPCACHE", "").strip().lower() == "true"
Expand Down Expand Up @@ -99,21 +198,33 @@ def __call__(
seed = kwargs.pop("seed", None)
num_inference_steps = kwargs.get("num_inference_steps", None)
safety_check = kwargs.pop("safety_check", True)

torch_device = get_torch_device()
if seed is not None:
if isinstance(seed, int):
kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(
kwargs["generator"] = torch.Generator(torch_device).manual_seed(
seed
)
elif isinstance(seed, list):
kwargs["generator"] = [
torch.Generator(get_torch_device()).manual_seed(s) for s in seed
torch.Generator(torch_device).manual_seed(s) for s in seed
]

if num_inference_steps is None or num_inference_steps < 1:
del kwargs["num_inference_steps"]

output = self.ldm(prompt, image=image, **kwargs)
if self.model_id == "stabilityai/stable-diffusion-x4-upscaler":
output = self.ldm(prompt, image=image, **kwargs)
else:
max_scale = self.get_max_scale_for_input(image)
if self.ldm.scale > max_scale:
raise ValueError("requested scale too high")

# Convert PIL image to NumPy array
img_tensor = self.pil_to_tensor(image)
img_tensor = img_tensor.to(torch_device)
output = self.ldm(img_tensor)
output = self.tensor_to_pil(output)
output.images = [output]

if safety_check:
_, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images)
Expand All @@ -124,3 +235,37 @@ def __call__(

def __str__(self) -> str:
return f"UpscalePipeline model_id={self.model_id}"

# Load and preprocess the image
def pil_to_tensor(self, image: PIL.Image) -> torch.Tensor:
# Convert PIL image to NumPy array
img = np.array(image).astype(np.float32) / 255.0

# Convert to tensor and reorder dimensions to (C, H, W)
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).unsqueeze(0)

return img

# Postprocess the output from ESRGAN
def tensor_to_pil(self, tensor: torch.Tensor) -> PIL.Image:
# Remove the batch dimension and reorder dimensions to (H, W, C)
img = tensor.squeeze(0).cpu().numpy()
img = np.transpose(img, (1, 2, 0))

# Convert the pixel values to the [0, 255] range and convert to uint8
img = (img * 255.0).clip(0, 255).astype(np.uint8)

return Image.fromarray(img)

def get_max_scale_for_input(self, image: PIL.Image) -> int:
w, h = image.size
if (w*h) > 1048576: #1024x1024
return 2
elif (w*h) > 65536: #256x256
return 4
else:
return 8




3 changes: 2 additions & 1 deletion runner/dl_checkpoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ function download_alpha_models() {

# Download upscale models
huggingface-cli download stabilityai/stable-diffusion-x4-upscaler --include "*.fp16.safetensors" --cache-dir models

huggingface-cli download ad-astra-video/upscalers --cache-dir models

# Download audio-to-text models.
huggingface-cli download openai/whisper-large-v3 --include "*.safetensors" "*.json" --cache-dir models

Expand Down
1 change: 1 addition & 0 deletions runner/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ numpy==1.26.4
av==12.1.0
sentencepiece== 0.2.0
protobuf==5.27.2
spandrel==0.3.4

0 comments on commit f83693e

Please sign in to comment.