From f83693e6e15fb722ef9a2468492be530ab26cdea Mon Sep 17 00:00:00 2001
From: Brad P
Date: Sat, 24 Aug 2024 10:51:42 -0500
Subject: [PATCH] feature(ai): add spandrel to enable GAN upscalers

* add selected upscalers with permissive licenses
---
 runner/app/pipelines/upscale.py | 166 +++++++++++++++++++++++++++++---
 runner/dl_checkpoints.sh        |   3 +-
 runner/requirements.txt         |   1 +
 3 files changed, 159 insertions(+), 11 deletions(-)

diff --git a/runner/app/pipelines/upscale.py b/runner/app/pipelines/upscale.py
index e36e4606..c3979d59 100644
--- a/runner/app/pipelines/upscale.py
+++ b/runner/app/pipelines/upscale.py
@@ -1,8 +1,10 @@
 import logging
 import os
 from typing import List, Optional, Tuple
+from enum import Enum
 
 import PIL
+import PIL.Image
 import torch
 from app.pipelines.base import Pipeline
 from app.pipelines.utils import (
@@ -10,17 +12,104 @@
     get_model_dir,
     get_torch_device,
     is_lightning_model,
-    is_turbo_model,
+    is_turbo_model
 )
 from diffusers import StableDiffusionUpscalePipeline
-from huggingface_hub import file_download
-from PIL import ImageFile
+from spandrel import ModelLoader
+import numpy as np
+from huggingface_hub import file_download, hf_hub_download
+from PIL import ImageFile, Image
 
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 
 logger = logging.getLogger(__name__)
 
+
+class ModelName(Enum):
+    """Enumeration mapping supported GAN upscaler model IDs to checkpoints."""
+
+    BSRGAN_X2 = "bsrgan_x2"
+    BSRGAN_X4 = "bsrgan_x4"
+    NMKD_SUPERSCALE_SP_X4 = "nmkd_superscale_sp_x4"
+    NMKD_TYPESCALE_X8 = "nmkd_typescale_x8"
+    REALESRGAN_X2 = "realesrgan_x2"
+    REALESRGAN_X4 = "realesrgan_x4"
+    REALESRGAN_ANIME_X4 = "realesrgan_anime_x4"
+    REAL_HAT_GAN_SR_X4 = "real_hat_gan_sr_x4"
+    REAL_HAT_GAN_SR_X4_SHARPER = "real_hat_gan_sr_x4_sharper"
+    SCUNET_COLOR_GAN = "scunet_color_real_gan"
+    SCUNET_COLOR_PSNR = "scunet_color_real_psnr"
+    SWIN2SR_CLASSICAL_X2 = "swin2sr_classical_x2"
+    SWIN2SR_CLASSICAL_X4 = "swin2sr_classical_x4"
+
+    @classmethod
+    def list(cls):
+        """Return a list of all model IDs."""
+        return list(map(lambda c: c.value, cls))
+
+    @classmethod
+    def get_model_file(cls, model):
+        """Download (if not cached) and return the checkpoint path for *model*."""
+        match model:
+            case cls.BSRGAN_X2.value:
+                return hf_hub_download("ad-astra-video/upscalers", filename="BSRGANx2.pth", cache_dir=get_model_dir())
+            case cls.BSRGAN_X4.value:
+                return hf_hub_download("ad-astra-video/upscalers", filename="BSRGANx4.pth", cache_dir=get_model_dir())
+            case cls.NMKD_SUPERSCALE_SP_X4.value:
+                return hf_hub_download("ad-astra-video/upscalers", filename="4x_NMKD-Superscale-SP_178000_G.pth", cache_dir=get_model_dir())
+            case cls.NMKD_TYPESCALE_X8.value:
+                return hf_hub_download("ad-astra-video/upscalers", filename="8x_NMKD-Typescale_175k.pth", cache_dir=get_model_dir())
+            case cls.REALESRGAN_X2.value:
+                return hf_hub_download("ad-astra-video/upscalers", filename="RealESRGAN_x2plus.pth", cache_dir=get_model_dir())
+            case cls.REALESRGAN_X4.value:
+                return hf_hub_download("ad-astra-video/upscalers", filename="RealESRGAN_x4plus.pth", cache_dir=get_model_dir())
+            case cls.REALESRGAN_ANIME_X4.value:
+                return hf_hub_download("ad-astra-video/upscalers", filename="RealESRGAN_x4plus_anime_6B.pth", cache_dir=get_model_dir())
+            case cls.REAL_HAT_GAN_SR_X4.value:
+                return hf_hub_download("ad-astra-video/upscalers", filename="Real_HAT_GAN_SRx4.pth", cache_dir=get_model_dir())
+            case cls.REAL_HAT_GAN_SR_X4_SHARPER.value:
+                return hf_hub_download("ad-astra-video/upscalers", filename="Real_HAT_GAN_sharper.pth", cache_dir=get_model_dir())
+            case cls.SCUNET_COLOR_GAN.value:
+                return hf_hub_download("ad-astra-video/upscalers", filename="scunet_color_real_gan.pth", cache_dir=get_model_dir())
+            case cls.SCUNET_COLOR_PSNR.value:
+                return hf_hub_download("ad-astra-video/upscalers", filename="scunet_color_real_psnr.pth", cache_dir=get_model_dir())
+            case cls.SWIN2SR_CLASSICAL_X2.value:
+                return hf_hub_download("ad-astra-video/upscalers", filename="Swin2SR_ClassicalSR_X2_64.pth", cache_dir=get_model_dir())
+            case cls.SWIN2SR_CLASSICAL_X4.value:
+                return hf_hub_download("ad-astra-video/upscalers", filename="Swin2SR_ClassicalSR_X4_64.pth", cache_dir=get_model_dir())
+
+    @classmethod
+    def get_model_scale(cls, model):
+        """Return the native upscale factor for *model*."""
+        match model:
+            case cls.BSRGAN_X2.value:
+                return 2
+            case cls.BSRGAN_X4.value:
+                return 4
+            case cls.NMKD_SUPERSCALE_SP_X4.value:
+                return 4
+            case cls.NMKD_TYPESCALE_X8.value:
+                return 8
+            case cls.REALESRGAN_X2.value:
+                return 2
+            case cls.REALESRGAN_X4.value:
+                return 4
+            case cls.REALESRGAN_ANIME_X4.value:
+                return 4
+            case cls.REAL_HAT_GAN_SR_X4.value:
+                return 4
+            case cls.REAL_HAT_GAN_SR_X4_SHARPER.value:
+                return 4
+            case cls.SCUNET_COLOR_GAN.value:
+                return 4
+            case cls.SCUNET_COLOR_PSNR.value:
+                return 4
+            case cls.SWIN2SR_CLASSICAL_X2.value:
+                return 2
+            case cls.SWIN2SR_CLASSICAL_X4.value:
+                return 4
+
+
 class UpscalePipeline(Pipeline):
     def __init__(self, model_id: str):
         self.model_id = model_id
@@ -42,9 +131,22 @@ def __init__(self, model_id: str):
         kwargs["torch_dtype"] = torch.float16
         kwargs["variant"] = "fp16"
 
-        self.ldm = StableDiffusionUpscalePipeline.from_pretrained(
-            model_id, **kwargs
-        ).to(torch_device)
+        if self.model_id == "stabilityai/stable-diffusion-x4-upscaler":
+            self.ldm = StableDiffusionUpscalePipeline.from_pretrained(
+                model_id, **kwargs
+            ).to(torch_device)
+        else:
+            if self.model_id in ModelName.list():
+                # use spandrel to load the GAN upscaler checkpoint
+                model_file = ModelName.get_model_file(self.model_id)
+                logger.info(f"loading model file: {model_file}")
+                model = ModelLoader().load_from_file(model_file)
+                logger.info(f"model loaded, scale={model.scale}")
+                # move to the configured device (not hard-coded CUDA) and set eval mode
+                model.to(torch_device).eval()
+                self.ldm = model
+            else:
+                raise ValueError("Model not supported")
 
         sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
         deepcache_enabled = os.getenv("DEEPCACHE", "").strip().lower() == "true"
@@ -99,21 +201,33 @@ def __call__(
         seed = kwargs.pop("seed", None)
         num_inference_steps = kwargs.get("num_inference_steps", None)
         safety_check = kwargs.pop("safety_check", True)
-
+        torch_device = get_torch_device()
         if seed is not None:
             if isinstance(seed, int):
-                kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(
+                kwargs["generator"] = torch.Generator(torch_device).manual_seed(
                     seed
                 )
             elif isinstance(seed, list):
                 kwargs["generator"] = [
-                    torch.Generator(get_torch_device()).manual_seed(s) for s in seed
+                    torch.Generator(torch_device).manual_seed(s) for s in seed
                 ]
 
         if num_inference_steps is None or num_inference_steps < 1:
             del kwargs["num_inference_steps"]
 
-        output = self.ldm(prompt, image=image, **kwargs)
+        if self.model_id == "stabilityai/stable-diffusion-x4-upscaler":
+            output = self.ldm(prompt, image=image, **kwargs)
+        else:
+            max_scale = self.get_max_scale_for_input(image)
+            if self.ldm.scale > max_scale:
+                raise ValueError("requested scale too high")
+
+            # Convert the PIL image to a normalized tensor on the target device
+            img_tensor = self.pil_to_tensor(image)
+            img_tensor = img_tensor.to(torch_device)
+            output = self.ldm(img_tensor)
+            output = self.tensor_to_pil(output)
+            output.images = [output]
 
         if safety_check:
             _, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images)
@@ -124,3 +238,34 @@ def __call__(
 
     def __str__(self) -> str:
         return f"UpscalePipeline model_id={self.model_id}"
+
+    # Convert a PIL image to a normalized float tensor of shape (1, C, H, W).
+    def pil_to_tensor(self, image: PIL.Image) -> torch.Tensor:
+        # Scale pixel values from [0, 255] to [0.0, 1.0]
+        img = np.array(image).astype(np.float32) / 255.0
+
+        # Reorder dimensions to (C, H, W) and add a batch dimension
+        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).unsqueeze(0)
+
+        return img
+
+    # Convert a (1, C, H, W) model output tensor back to a PIL image.
+    def tensor_to_pil(self, tensor: torch.Tensor) -> PIL.Image:
+        # Remove the batch dimension and reorder dimensions to (H, W, C)
+        img = tensor.squeeze(0).cpu().numpy()
+        img = np.transpose(img, (1, 2, 0))
+
+        # Convert the pixel values to the [0, 255] range and convert to uint8
+        img = (img * 255.0).clip(0, 255).astype(np.uint8)
+
+        return Image.fromarray(img)
+
+    # Cap the allowed upscale factor by input size to bound output memory.
+    def get_max_scale_for_input(self, image: PIL.Image) -> int:
+        w, h = image.size
+        if (w * h) > 1048576:  # larger than 1024x1024 -> at most 2x
+            return 2
+        elif (w * h) > 65536:  # larger than 256x256 -> at most 4x
+            return 4
+        else:
+            return 8
diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh
index 9fe40837..8f41e0de 100755
--- a/runner/dl_checkpoints.sh
+++ b/runner/dl_checkpoints.sh
@@ -30,7 +30,8 @@ function download_alpha_models() {
 
     # Download upscale models
     huggingface-cli download stabilityai/stable-diffusion-x4-upscaler --include "*.fp16.safetensors" --cache-dir models
-
+    huggingface-cli download ad-astra-video/upscalers --cache-dir models
+
     # Download audio-to-text models.
     huggingface-cli download openai/whisper-large-v3 --include "*.safetensors" "*.json" --cache-dir models
 
diff --git a/runner/requirements.txt b/runner/requirements.txt
index 24f2442f..8d924133 100644
--- a/runner/requirements.txt
+++ b/runner/requirements.txt
@@ -17,3 +17,4 @@ numpy==1.26.4
 av==12.1.0
 sentencepiece== 0.2.0
 protobuf==5.27.2
+spandrel==0.3.4