Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhanced Deep Residual Networks for single-image super-resolution - Keras 3 migration (Only Tensorflow Backend) #1920

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 33 additions & 21 deletions examples/vision/edsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Title: Enhanced Deep Residual Networks for single-image super-resolution
Author: Gitesh Chawda
Date created: 2022/04/07
Last modified: 2022/04/07
Last modified: 2024/08/27
Description: Training an EDSR model on the DIV2K Dataset.
Accelerator: GPU
"""
Expand Down Expand Up @@ -40,14 +40,18 @@
"""
## Imports
"""
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

from tensorflow import keras
from tensorflow.keras import layers
import keras
from keras import layers
from keras import ops

AUTOTUNE = tf.data.AUTOTUNE

Expand Down Expand Up @@ -81,15 +85,15 @@ def flip_left_right(lowres_img, highres_img):
"""Flips Images to left and right."""

# Outputs random values from a uniform distribution in between 0 to 1
rn = tf.random.uniform(shape=(), maxval=1)
rn = keras.random.uniform(shape=(), maxval=1)
# If rn is less than 0.5 it returns the original lowres_img and highres_img
# Otherwise (rn >= 0.5) it returns the flipped images
return tf.cond(
return ops.cond(
rn < 0.5,
lambda: (lowres_img, highres_img),
lambda: (
tf.image.flip_left_right(lowres_img),
tf.image.flip_left_right(highres_img),
ops.flip(lowres_img),
ops.flip(highres_img),
),
)

Expand All @@ -98,7 +102,9 @@ def random_rotate(lowres_img, highres_img):
"""Rotates Images by 90 degrees."""

# Draws a random integer in the range 0 to 3 (uniform over [0, 4), upper bound excluded)
rn = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
rn = ops.cast(
keras.random.uniform(shape=(), maxval=4, dtype="float32"), dtype="int32"
)
# Here rn signifies number of times the image(s) are rotated by 90 degrees
return tf.image.rot90(lowres_img, rn), tf.image.rot90(highres_img, rn)

Expand All @@ -110,13 +116,19 @@ def random_crop(lowres_img, highres_img, hr_crop_size=96, scale=4):
high resolution images: 96x96
"""
lowres_crop_size = hr_crop_size // scale # 96//4=24
lowres_img_shape = tf.shape(lowres_img)[:2] # (height,width)
lowres_img_shape = ops.shape(lowres_img)[:2] # (height,width)

lowres_width = tf.random.uniform(
shape=(), maxval=lowres_img_shape[1] - lowres_crop_size + 1, dtype=tf.int32
lowres_width = ops.cast(
keras.random.uniform(
shape=(), maxval=lowres_img_shape[1] - lowres_crop_size + 1, dtype="float32"
),
dtype="int32",
)
lowres_height = tf.random.uniform(
shape=(), maxval=lowres_img_shape[0] - lowres_crop_size + 1, dtype=tf.int32
lowres_height = ops.cast(
keras.random.uniform(
shape=(), maxval=lowres_img_shape[0] - lowres_crop_size + 1, dtype="float32"
),
dtype="int32",
)

highres_width = lowres_width * scale
Expand Down Expand Up @@ -218,7 +230,7 @@ def PSNR(super_resolution, high_resolution):
"""


class EDSRModel(tf.keras.Model):
class EDSRModel(keras.Model):
def train_step(self, data):
# Unpack the data. Its structure depends on your model and
# on what you pass to `fit()`.
Expand All @@ -242,16 +254,16 @@ def train_step(self, data):

def predict_step(self, x):
# Adding dummy dimension using tf.expand_dims and converting to float32 using ops.cast
x = tf.cast(tf.expand_dims(x, axis=0), tf.float32)
x = ops.cast(tf.expand_dims(x, axis=0), dtype="float32")
# Passing low resolution image to model
super_resolution_img = self(x, training=False)
# Clips the tensor from min(0) to max(255)
super_resolution_img = tf.clip_by_value(super_resolution_img, 0, 255)
super_resolution_img = ops.clip(super_resolution_img, 0, 255)
# Rounds the values of a tensor to the nearest integer
super_resolution_img = tf.round(super_resolution_img)
super_resolution_img = ops.round(super_resolution_img)
# Converts to uint8 and removes the dummy dimension (axis 0) that was added above
super_resolution_img = tf.squeeze(
tf.cast(super_resolution_img, tf.uint8), axis=0
super_resolution_img = ops.squeeze(
ops.cast(super_resolution_img, dtype="uint8"), axis=0
)
return super_resolution_img

Expand All @@ -267,9 +279,9 @@ def ResBlock(inputs):
# Upsampling Block
def Upsampling(inputs, factor=2, **kwargs):
x = layers.Conv2D(64 * (factor**2), 3, padding="same", **kwargs)(inputs)
x = tf.nn.depth_to_space(x, block_size=factor)
x = layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=factor))(x)
x = layers.Conv2D(64 * (factor**2), 3, padding="same", **kwargs)(x)
x = tf.nn.depth_to_space(x, block_size=factor)
x = layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=factor))(x)
return x


Expand Down
Binary file added examples/vision/img/edsr/edsr_11_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/vision/img/edsr/edsr_11_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/vision/img/edsr/edsr_17_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/vision/img/edsr/edsr_17_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/vision/img/edsr/edsr_17_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/vision/img/edsr/edsr_17_3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
88 changes: 51 additions & 37 deletions examples/vision/ipynb/edsr.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"\n",
"**Author:** Gitesh Chawda<br>\n",
"**Date created:** 2022/04/07<br>\n",
"**Last modified:** 2022/04/07<br>\n",
"**Last modified:** 2024/08/27<br>\n",
"**Description:** Training an EDSR model on the DIV2K Dataset."
]
},
Expand Down Expand Up @@ -39,7 +39,7 @@
"you can do super-resolution using an ESPCN Model. According to the survey paper, EDSR is one of the top-five\n",
"best-performing super-resolution methods based on PSNR scores. However, it has more\n",
"parameters and requires more computational power than other approaches.\n",
"It has a PSNR value (≈34db) that is slightly higher than ESPCN (≈32db).\n",
"It has a PSNR value (\u224834db) that is slightly higher than ESPCN (\u224832db).\n",
"As per the survey paper, EDSR performs better than ESPCN.\n",
"\n",
"Paper:\n",
Expand All @@ -60,19 +60,24 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
"\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"import keras\n",
"from keras import layers\n",
"from keras import ops\n",
"\n",
"AUTOTUNE = tf.data.AUTOTUNE"
]
Expand All @@ -93,7 +98,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
Expand Down Expand Up @@ -123,7 +128,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
Expand All @@ -134,15 +139,15 @@
" \"\"\"Flips Images to left and right.\"\"\"\n",
"\n",
" # Outputs random values from a uniform distribution in between 0 to 1\n",
" rn = tf.random.uniform(shape=(), maxval=1)\n",
" rn = keras.random.uniform(shape=(), maxval=1)\n",
" # If rn is less than 0.5 it returns original lowres_img and highres_img\n",
" # If rn is greater than 0.5 it returns flipped image\n",
" return tf.cond(\n",
" return ops.cond(\n",
" rn < 0.5,\n",
" lambda: (lowres_img, highres_img),\n",
" lambda: (\n",
" tf.image.flip_left_right(lowres_img),\n",
" tf.image.flip_left_right(highres_img),\n",
" ops.flip(lowres_img),\n",
" ops.flip(highres_img),\n",
" ),\n",
" )\n",
"\n",
Expand All @@ -151,7 +156,9 @@
" \"\"\"Rotates Images by 90 degrees.\"\"\"\n",
"\n",
" # Outputs random values from uniform distribution in between 0 to 4\n",
" rn = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)\n",
" rn = ops.cast(\n",
" keras.random.uniform(shape=(), maxval=4, dtype=\"float32\"), dtype=\"int32\"\n",
" )\n",
" # Here rn signifies number of times the image(s) are rotated by 90 degrees\n",
" return tf.image.rot90(lowres_img, rn), tf.image.rot90(highres_img, rn)\n",
"\n",
Expand All @@ -163,13 +170,19 @@
" high resolution images: 96x96\n",
" \"\"\"\n",
" lowres_crop_size = hr_crop_size // scale # 96//4=24\n",
" lowres_img_shape = tf.shape(lowres_img)[:2] # (height,width)\n",
" lowres_img_shape = ops.shape(lowres_img)[:2] # (height,width)\n",
"\n",
" lowres_width = tf.random.uniform(\n",
" shape=(), maxval=lowres_img_shape[1] - lowres_crop_size + 1, dtype=tf.int32\n",
" lowres_width = ops.cast(\n",
" keras.random.uniform(\n",
" shape=(), maxval=lowres_img_shape[1] - lowres_crop_size + 1, dtype=\"float32\"\n",
" ),\n",
" dtype=\"int32\",\n",
" )\n",
" lowres_height = tf.random.uniform(\n",
" shape=(), maxval=lowres_img_shape[0] - lowres_crop_size + 1, dtype=tf.int32\n",
" lowres_height = ops.cast(\n",
" keras.random.uniform(\n",
" shape=(), maxval=lowres_img_shape[0] - lowres_crop_size + 1, dtype=\"float32\"\n",
" ),\n",
" dtype=\"int32\",\n",
" )\n",
"\n",
" highres_width = lowres_width * scale\n",
Expand All @@ -184,7 +197,8 @@
" highres_width : highres_width + hr_crop_size,\n",
" ] # 96x96\n",
"\n",
" return lowres_img_cropped, highres_img_cropped\n"
" return lowres_img_cropped, highres_img_cropped\n",
""
]
},
{
Expand All @@ -202,15 +216,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"\n",
"def dataset_object(dataset_cache, training=True):\n",
"\n",
" ds = dataset_cache\n",
" ds = ds.map(\n",
" lambda lowres, highres: random_crop(lowres, highres, scale=4),\n",
Expand Down Expand Up @@ -248,7 +261,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
Expand Down Expand Up @@ -277,7 +290,8 @@
" \"\"\"Compute the peak signal-to-noise ratio, measures quality of image.\"\"\"\n",
" # Max value of pixel is 255\n",
" psnr_value = tf.image.psnr(high_resolution, super_resolution, max_val=255)[0]\n",
" return psnr_value\n"
" return psnr_value\n",
""
]
},
{
Expand Down Expand Up @@ -305,14 +319,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"\n",
"class EDSRModel(tf.keras.Model):\n",
"class EDSRModel(keras.Model):\n",
" def train_step(self, data):\n",
" # Unpack the data. Its structure depends on your model and\n",
" # on what you pass to `fit()`.\n",
Expand All @@ -336,16 +350,16 @@
"\n",
" def predict_step(self, x):\n",
" # Adding dummy dimension using tf.expand_dims and converting to float32 using tf.cast\n",
" x = tf.cast(tf.expand_dims(x, axis=0), tf.float32)\n",
" x = ops.cast(tf.expand_dims(x, axis=0), dtype=\"float32\")\n",
" # Passing low resolution image to model\n",
" super_resolution_img = self(x, training=False)\n",
" # Clips the tensor from min(0) to max(255)\n",
" super_resolution_img = tf.clip_by_value(super_resolution_img, 0, 255)\n",
" super_resolution_img = ops.clip(super_resolution_img, 0, 255)\n",
" # Rounds the values of a tensor to the nearest integer\n",
" super_resolution_img = tf.round(super_resolution_img)\n",
" super_resolution_img = ops.round(super_resolution_img)\n",
" # Removes dimensions of size 1 from the shape of a tensor and converting to uint8\n",
" super_resolution_img = tf.squeeze(\n",
" tf.cast(super_resolution_img, tf.uint8), axis=0\n",
" super_resolution_img = ops.squeeze(\n",
" ops.cast(super_resolution_img, dtype=\"uint8\"), axis=0\n",
" )\n",
" return super_resolution_img\n",
"\n",
Expand All @@ -360,10 +374,10 @@
"\n",
"# Upsampling Block\n",
"def Upsampling(inputs, factor=2, **kwargs):\n",
" x = layers.Conv2D(64 * (factor ** 2), 3, padding=\"same\", **kwargs)(inputs)\n",
" x = tf.nn.depth_to_space(x, block_size=factor)\n",
" x = layers.Conv2D(64 * (factor ** 2), 3, padding=\"same\", **kwargs)(x)\n",
" x = tf.nn.depth_to_space(x, block_size=factor)\n",
" x = layers.Conv2D(64 * (factor**2), 3, padding=\"same\", **kwargs)(inputs)\n",
" x = layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=factor))(x)\n",
" x = layers.Conv2D(64 * (factor**2), 3, padding=\"same\", **kwargs)(x)\n",
" x = layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=factor))(x)\n",
" return x\n",
"\n",
"\n",
Expand Down Expand Up @@ -402,7 +416,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
Expand Down Expand Up @@ -431,7 +445,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
Expand Down Expand Up @@ -473,7 +487,7 @@
"\n",
"| Trained Model | Demo |\n",
"| :--: | :--: |\n",
"| [![Generic badge](https://img.shields.io/badge/🤗%20Model-EDSR-red.svg)](https://huggingface.co/keras-io/EDSR) | [![Generic badge](https://img.shields.io/badge/🤗%20Spaces-EDSR-red.svg)](https://huggingface.co/spaces/keras-io/EDSR) |"
"| [![Generic badge](https://img.shields.io/badge/\ud83e\udd17%20Model-EDSR-red.svg)](https://huggingface.co/keras-io/EDSR) | [![Generic badge](https://img.shields.io/badge/\ud83e\udd17%20Spaces-EDSR-red.svg)](https://huggingface.co/spaces/keras-io/EDSR) |"
]
}
],
Expand Down Expand Up @@ -506,4 +520,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}
Loading
Loading