Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(ai): enable GAN upscalers with spandrel #176

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 155 additions & 10 deletions runner/app/pipelines/upscale.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,112 @@
import logging
import os
from typing import List, Optional, Tuple
from enum import Enum

import PIL
import PIL.Image
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import (
SafetyChecker,
get_model_dir,
get_torch_device,
is_lightning_model,
is_turbo_model,
is_turbo_model
)
from diffusers import StableDiffusionUpscalePipeline
from huggingface_hub import file_download
from PIL import ImageFile
from spandrel import ModelLoader
import numpy as np
from huggingface_hub import file_download, hf_hub_download
from PIL import ImageFile, Image

ImageFile.LOAD_TRUNCATED_IMAGES = True

logger = logging.getLogger(__name__)

class ModelName(Enum):
    """Upscaler models that are loaded through spandrel.

    Each member's value is the model ID string that callers pass around.
    The checkpoint filename and upscaling factor for every member are kept
    side by side in ``_MODEL_SPECS`` (defined right after this class) so the
    two lookups below cannot drift apart.
    """

    BSRGAN_X2 = "bsrgan_x2"
    BSRGAN_X4 = "bsrgan_x4"
    NMKD_SUPERSCALE_SP_X4 = "nmkd_superscale_sp_x4"
    # This ID used to read "nmkd_typescale_x4" even though the member name,
    # the checkpoint file (8x_NMKD-Typescale_175k.pth) and the scale are 8x.
    NMKD_TYPESCALE_X8 = "nmkd_typescale_x8"
    REALESRGAN_X2 = "realesrgan_x2"
    REALESRGAN_X4 = "realesrgan_x4"
    REALESRGAN_ANIME_X4 = "realesrgan_anime_x4"
    REAL_HAT_GAN_SR_X4 = "real_hat_gan_sr_x4"
    REAL_HAT_GAN_SR_X4_SHARPER = "real_hat_gan_sr_x4_sharper"
    SCUNET_COLOR_GAN = "scunet_color_real_gan"
    SCUNET_COLOR_PSNR = "scunet_color_real_psnr"
    SWIN2SR_CLASSICAL_X2 = "swin2sr_classical_x2"
    SWIN2SR_CLASSICAL_X4 = "swin2sr_classical_x4"

    @classmethod
    def list(cls) -> List[str]:
        """Return the model ID string of every member, in declaration order."""
        return [member.value for member in cls]

    @classmethod
    def get_model_file(cls, model: str) -> str:
        """Download (or reuse from cache) the checkpoint for ``model``.

        Args:
            model: A model ID as returned by :meth:`list` (a ``ModelName``
                member is accepted as well).

        Returns:
            Whatever ``hf_hub_download`` returns for the checkpoint file.

        Raises:
            ValueError: If ``model`` is not a known model ID. Previously an
                unknown ID silently produced ``None``.
        """
        filename, _ = _MODEL_SPECS[cls(model)]
        return hf_hub_download(
            _UPSCALER_REPO, filename=filename, cache_dir=get_model_dir()
        )

    @classmethod
    def get_model_scale(cls, model: str) -> int:
        """Return the upscaling factor recorded for ``model``.

        Args:
            model: A model ID as returned by :meth:`list` (a ``ModelName``
                member is accepted as well).

        Raises:
            ValueError: If ``model`` is not a known model ID. Previously an
                unknown ID silently produced ``None``.
        """
        _, scale = _MODEL_SPECS[cls(model)]
        return scale


# Hugging Face repository that hosts every checkpoint listed below.
_UPSCALER_REPO = "ad-astra-video/upscalers"

# ModelName member -> (checkpoint filename inside _UPSCALER_REPO, scale factor).
_MODEL_SPECS = {
    ModelName.BSRGAN_X2: ("BSRGANx2.pth", 2),
    ModelName.BSRGAN_X4: ("BSRGANx4.pth", 4),
    ModelName.NMKD_SUPERSCALE_SP_X4: ("4x_NMKD-Superscale-SP_178000_G.pth", 4),
    ModelName.NMKD_TYPESCALE_X8: ("8x_NMKD-Typescale_175k.pth", 8),
    ModelName.REALESRGAN_X2: ("RealESRGAN_x2plus.pth", 2),
    ModelName.REALESRGAN_X4: ("RealESRGAN_x4plus.pth", 4),
    ModelName.REALESRGAN_ANIME_X4: ("RealESRGAN_x4plus_anime_6B.pth", 4),
    ModelName.REAL_HAT_GAN_SR_X4: ("Real_HAT_GAN_SRx4.pth", 4),
    ModelName.REAL_HAT_GAN_SR_X4_SHARPER: ("Real_HAT_GAN_sharper.pth", 4),
    # NOTE(review): SCUNet is usually described as a restoration/denoising
    # model; confirm that 4 matches the `scale` spandrel reports for these
    # two checkpoints. The values are kept as they were.
    ModelName.SCUNET_COLOR_GAN: ("scunet_color_real_gan.pth", 4),
    ModelName.SCUNET_COLOR_PSNR: ("scunet_color_real_psnr.pth", 4),
    ModelName.SWIN2SR_CLASSICAL_X2: ("Swin2SR_ClassicalSR_X2_64.pth", 2),
    ModelName.SWIN2SR_CLASSICAL_X4: ("Swin2SR_ClassicalSR_X4_64.pth", 4),
}


class UpscalePipeline(Pipeline):
def __init__(self, model_id: str):
self.model_id = model_id
Expand All @@ -42,9 +128,22 @@ def __init__(self, model_id: str):
kwargs["torch_dtype"] = torch.float16
kwargs["variant"] = "fp16"

self.ldm = StableDiffusionUpscalePipeline.from_pretrained(
model_id, **kwargs
).to(torch_device)
if self.model_id == "stabilityai/stable-diffusion-x4-upscaler":
self.ldm = StableDiffusionUpscalePipeline.from_pretrained(
model_id, **kwargs
).to(torch_device)
else:
if self.model_id in ModelName.list():
#use spandrel to load the model
model_file = ModelName.get_model_file(self.model_id)
logger.info(f"loading model file: {model_file}")
model = ModelLoader().load_from_file(model_file)
logger.info(f"model loaded, scale={model.scale}")
# send it to the GPU and put it in inference mode
model.cuda().eval()
self.ldm = model
else:
raise ValueError("Model not supported")

sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
deepcache_enabled = os.getenv("DEEPCACHE", "").strip().lower() == "true"
Expand Down Expand Up @@ -99,21 +198,33 @@ def __call__(
seed = kwargs.pop("seed", None)
num_inference_steps = kwargs.get("num_inference_steps", None)
safety_check = kwargs.pop("safety_check", True)

torch_device = get_torch_device()
if seed is not None:
if isinstance(seed, int):
kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(
kwargs["generator"] = torch.Generator(torch_device).manual_seed(
seed
)
elif isinstance(seed, list):
kwargs["generator"] = [
torch.Generator(get_torch_device()).manual_seed(s) for s in seed
torch.Generator(torch_device).manual_seed(s) for s in seed
]

if num_inference_steps is None or num_inference_steps < 1:
del kwargs["num_inference_steps"]

output = self.ldm(prompt, image=image, **kwargs)
if self.model_id == "stabilityai/stable-diffusion-x4-upscaler":
output = self.ldm(prompt, image=image, **kwargs)
else:
max_scale = self.get_max_scale_for_input(image)
if self.ldm.scale > max_scale:
raise ValueError("requested scale too high")

# Convert PIL image to NumPy array
img_tensor = self.pil_to_tensor(image)
img_tensor = img_tensor.to(torch_device)
output = self.ldm(img_tensor)
output = self.tensor_to_pil(output)
output.images = [output]

if safety_check:
_, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images)
Expand All @@ -124,3 +235,37 @@ def __call__(

def __str__(self) -> str:
return f"UpscalePipeline model_id={self.model_id}"

# Convert an in-memory PIL image into a normalized, batched float tensor
def pil_to_tensor(self, image: PIL.Image) -> torch.Tensor:
    """Turn ``image`` into a float32 tensor with a leading batch axis.

    Pixel values are divided by 255, the last (channel) axis is moved to
    the front, and a batch axis of size 1 is prepended, giving
    ``(1, C, H, W)``.

    NOTE(review): the transpose needs a 3-axis array, so this presumably
    relies on the caller passing a multi-channel (e.g. RGB) image; a
    single-channel image would fail here. Confirm against the caller.
    """
    pixels = np.array(image).astype(np.float32)
    scaled = pixels / 255.0

    # (H, W, C) -> (C, H, W); from_numpy shares memory with the array.
    channels_first = scaled.transpose(2, 0, 1)

    # Indexing with None prepends the batch axis.
    return torch.from_numpy(channels_first)[None]

# Convert the upscaler's output tensor back into a PIL image
def tensor_to_pil(self, tensor: torch.Tensor) -> PIL.Image.Image:
    """Convert a model output tensor into an 8-bit PIL image.

    Args:
        tensor: Image tensor laid out as ``(1, C, H, W)`` (or ``(C, H, W)``);
            values are scaled by 255 and clipped to ``[0, 255]``.

    Returns:
        The image built from the resulting ``(H, W, C)`` uint8 array.
    """
    # detach(): Tensor.numpy() raises on a tensor that still requires grad,
    #   which happens when the model was not run under no_grad.
    # float(): NumPy has no bfloat16, and this keeps the scaling below in
    #   float32 for half-precision outputs.
    chw = tensor.detach().squeeze(0).cpu().float().numpy()

    # (C, H, W) -> (H, W, C), the layout Image.fromarray expects.
    hwc = np.transpose(chw, (1, 2, 0))

    # Round to the nearest level before the uint8 cast; a bare cast
    # truncates and biases every pixel downward by up to one level.
    pixels = np.clip(hwc * 255.0, 0, 255).round().astype(np.uint8)

    return Image.fromarray(pixels)

def get_max_scale_for_input(self, image: PIL.Image) -> int:
    """Return the largest upscale factor allowed for ``image``.

    The limit depends only on the pixel count (width * height):
    up to 256x256 pixels -> 8, up to 1024x1024 pixels -> 4, larger -> 2.
    """
    width, height = image.size
    pixel_count = width * height

    if pixel_count <= 256 * 256:
        return 8
    if pixel_count <= 1024 * 1024:
        return 4
    return 2




3 changes: 2 additions & 1 deletion runner/dl_checkpoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ function download_alpha_models() {

# Download upscale models
huggingface-cli download stabilityai/stable-diffusion-x4-upscaler --include "*.fp16.safetensors" --cache-dir models

huggingface-cli download ad-astra-video/upscalers --cache-dir models

# Download audio-to-text models.
huggingface-cli download openai/whisper-large-v3 --include "*.safetensors" "*.json" --cache-dir models

Expand Down
1 change: 1 addition & 0 deletions runner/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ numpy==1.26.4
av==12.1.0
sentencepiece== 0.2.0
protobuf==5.27.2
spandrel==0.3.4
Loading