Super-Resolution Generative Adversarial Network, or SRGAN, is a Generative Adversarial Network (GAN) that generates high-resolution images from low-resolution inputs, recovering finer details and producing higher-quality results. Earlier approaches used CNNs to produce high-resolution images. These models train quickly and achieve high accuracy, but in some cases they cannot recover finer details and often generate blurry images.
In this tutorial, we will learn how to implement an SRGAN network in the Keras framework that can generate high-resolution images. SRGAN was introduced in the paper Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network, by Christian Ledig, Lucas Theis, Ferenc Huszar, Jose Caballero, Andrew Cunningham, and others.
This tutorial is an excerpt taken from the book Generative Adversarial Networks Projects written by Kailash Ahirwar. The book explores unsupervised techniques for training neural networks and includes seven end-to-end projects in the GAN domain.
Downloading the CelebA dataset
For this tutorial, we will use the large-scale CelebFaces Attributes (CelebA) dataset, which can be downloaded from the link given below. The dataset contains 202,599 face images of celebrities.
We will use the CelebA dataset to train our SRGAN network. Perform the following steps to download and extract the dataset:
- Download the dataset from the following link:
https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AAB06FXaQRUNtjW9ntaoPGvCa?dl=0
- Extract images from the downloaded img_align_celeba.zip by executing the following command:
unzip img_align_celeba.zip
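If the unzip command is not available (for example, on Windows), the archive can also be extracted from Python with the standard-library zipfile module. The following is a minimal sketch; the helper name extract_and_count is hypothetical, and the count assumes the images in the archive are .jpg files, as they are in img_align_celeba.zip:

```python
import glob
import os
import zipfile


def extract_and_count(zip_path, out_dir="."):
    """Extract zip_path into out_dir and return how many .jpg files it produced."""
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(out_dir)
    # Search out_dir and all of its subfolders (e.g. img_align_celeba/)
    pattern = os.path.join(out_dir, "**", "*.jpg")
    return len(glob.glob(pattern, recursive=True))


# Usage (hypothetical paths):
# n_images = extract_and_count("img_align_celeba.zip")
# print(n_images)  # compare with the dataset size quoted above
```

Counting the extracted files is a quick way to confirm that the download was not truncated before starting a long training run.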
We have now downloaded and extracted the dataset, so we can start working on the Keras implementation of SRGAN.
The Keras implementation of SRGAN
SRGAN has three neural networks: a generator, a discriminator, and a VGG19 network pre-trained on the ImageNet dataset. In this section, we will write the implementation for all the networks, starting with the generator network.
Before starting to write the implementations, create a Python file called main.py and import the essential modules, as follows:
import glob
import os
import time
import numpy as np
import tensorflow as tf
from keras import Input
from keras.applications import VGG19
from keras.callbacks import TensorBoard
from keras.layers import BatchNormalization, Activation, LeakyReLU, Add, Dense, PReLU, Flatten
from keras.layers.convolutional import Conv2D, UpSampling2D
from keras.models import Model
from keras.optimizers import Adam
from keras_preprocessing.image import img_to_array, load_img
# imread, imresize, and imsave were removed from scipy.misc in later SciPy
# releases, so this code needs an older SciPy version (with Pillow installed)
from scipy.misc import imread, imresize, imsave
The generator network
Let’s start by writing the layers for the generator network in the Keras framework and then create a Keras model, using the functional API of the Keras framework.
Perform the following steps to implement the generator network in Keras:
- Start by defining the hyperparameters required for the generator network:
residual_blocks = 16
momentum = 0.8
input_shape = (64, 64, 3)
- Next, create an input layer to feed input to the network, as follows:
input_layer = Input(shape=input_shape)
- Next, add the pre-residual block (2D convolution layer), as follows:
Configuration:
- Filters: 64
- Kernel size: 9
- Strides: 1
- Padding: same
- Activation: relu:
gen1 = Conv2D(filters=64, kernel_size=9, strides=1, padding='same', activation='relu')(input_layer)
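As a sanity check on this configuration, a 2D convolution layer has kernel_height × kernel_width × input_channels × filters weights plus one bias per filter. The small helper below (hypothetical, not part of the book's code) makes the arithmetic explicit for the pre-residual layer and for the 3 × 3 convolutions used in the residual blocks:

```python
def conv2d_params(kernel_size, in_channels, filters, use_bias=True):
    """Number of trainable parameters in a square Conv2D layer."""
    weights = kernel_size * kernel_size * in_channels * filters
    biases = filters if use_bias else 0
    return weights + biases


# Pre-residual block: 9x9 kernel, 3 input channels (RGB), 64 filters
print(conv2d_params(9, 3, 64))   # 15616
# Each 3x3 convolution inside a residual block: 64 -> 64 channels
print(conv2d_params(3, 64, 64))  # 36928
```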
- Next, write a function that contains the entire code for the residual block, as shown here:
def residual_block(x):
    """
    Residual block
    """
    filters = [64, 64]
    kernel_size = 3
    strides = 1
    padding = "same"
    momentum = 0.8
    activation = "relu"
    res = Conv2D(filters=filters[0], kernel_size=kernel_size,
                 strides=strides, padding=padding)(x)
    res = Activation(activation=activation)(res)
    res = BatchNormalization(momentum=momentum)(res)
    res = Conv2D(filters=filters[1], kernel_size=kernel_size,
                 strides=strides, padding=padding)(res)
    res = BatchNormalization(momentum=momentum)(res)
    # Add res and x
    res = Add()([res, x])
    return res
- Now, add 16 residual blocks using the residual_block function defined in the previous step:
res = residual_block(gen1)
for i in range(residual_blocks - 1):
    res = residual_block(res)
The output of the pre-residual block goes to the first residual block. The output of the first residual block goes to the second residual block, and so on, up to the 16th residual block.
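Chaining works because every convolution in the residual block uses strides=1, padding='same', and 64 filters, so each block returns a tensor with the same shape as its input, which is also what the Add layer requires. The sketch below checks this by stacking 16 blocks on a (64, 64, 64) input. It assumes tf.keras rather than the standalone keras package imported above; the layers used exist under the same names in both:

```python
from tensorflow.keras import Input
from tensorflow.keras.layers import Activation, Add, BatchNormalization, Conv2D


def residual_block(x, filters=64, momentum=0.8):
    # Same layer order as the tutorial: Conv -> ReLU -> BN -> Conv -> BN -> skip
    res = Conv2D(filters, kernel_size=3, strides=1, padding="same")(x)
    res = Activation("relu")(res)
    res = BatchNormalization(momentum=momentum)(res)
    res = Conv2D(filters, kernel_size=3, strides=1, padding="same")(res)
    res = BatchNormalization(momentum=momentum)(res)
    return Add()([res, x])  # res and x must have identical shapes


inputs = Input(shape=(64, 64, 64))  # shape of gen1, the pre-residual output
x = inputs
for _ in range(16):
    x = residual_block(x)
print(tuple(x.shape))  # (None, 64, 64, 64)
```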
- Next, add the post-residual block (a 2D convolution layer followed by a batch normalization layer), as follows:
Configuration:
- Filters: 64
- Kernel size: 3
- Strides: 1
- Padding: same
- Batch normalization: Yes (momentum=0.8):
gen2 = Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(res)
gen2 = BatchNormalization(momentum=momentum)(gen2)
- Now, add an Add layer that takes the sum of the output from the pre-residual block (gen1) and the output from the post-residual block (gen2). Both tensors have the shape (64, 64, 64) per image, so the result has that shape as well:
gen3 = Add()([gen2, gen1])
- Next, add an upsampling block, as follows:
Configuration:
- Upsampling size: 2
- Filters: 256
- Kernel size: 3
- Strides: 1
- Padding: same
- Activation: ReLU (the original SRGAN paper uses PReLU here):
gen4 = UpSampling2D(size=2)(gen3)
gen4 = Conv2D(filters=256, kernel_size=3, strides=1, padding='same')(gen4)
gen4 = Activation('relu')(gen4)
- Next, add another upsampling block, as follows:
Configuration:
- Upsampling size: 2
- Filters: 256
- Kernel size: 3
- Strides: 1
- Padding: same
- Activation: ReLU:
gen5 = UpSampling2D(size=2)(gen4)
gen5 = Conv2D(filters=256, kernel_size=3, strides=1, padding='same')(gen5)
gen5 = Activation('relu')(gen5)
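UpSampling2D(size=2) with its default 'nearest' interpolation simply repeats every pixel twice along the height and width axes; the Conv2D layer that follows then learns to refine the enlarged feature map. A NumPy sketch of the same operation on a channels-last batch (the helper name is hypothetical and it is for illustration only, not used in the model):

```python
import numpy as np


def upsample_nearest(x, size=2):
    """Nearest-neighbour upsampling of a (batch, height, width, channels) array."""
    return np.repeat(np.repeat(x, size, axis=1), size, axis=2)


x = np.arange(4, dtype=np.float32).reshape(1, 2, 2, 1)
y = upsample_nearest(x)
print(y.shape)  # (1, 4, 4, 1)
print(y[0, :, :, 0])
# [[0. 0. 1. 1.]
#  [0. 0. 1. 1.]
#  [2. 2. 3. 3.]
#  [2. 2. 3. 3.]]
```

Two such blocks take the 64 × 64 input to 128 × 128 and then to 256 × 256, the 4× upscaling used throughout this tutorial.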
- Finally, add the output convolution layer:
Configuration:
- Filters: 3 (equal to the number of image channels)
- Kernel size: 9
- Strides: 1
- Padding: same
- Activation: tanh:
gen6 = Conv2D(filters=3, kernel_size=9, strides=1, padding='same')(gen5)
output = Activation('tanh')(gen6)
Once you have defined all the layers in the network, you can create a Keras model. Because the layers were connected with Keras’s functional API, the model only needs to know its input and output tensors.
- Now, create a Keras model and specify the inputs and the outputs for the model, as follows:
model = Model(inputs=[input_layer], outputs=[output], name='generator')
We have successfully created a Keras model for the generator network. Now wrap the entire code for the generator network inside a Python function, as follows:
def build_generator():
    """
    Create a generator network using the hyperparameter values defined below
    :return: the generator as a Keras Model
    """
    residual_blocks = 16
    momentum = 0.8
    input_shape = (64, 64, 3)
    # Input Layer of the generator network
    input_layer = Input(shape=input_shape)
    # Add the pre-residual block
    gen1 = Conv2D(filters=64, kernel_size=9, strides=1, padding='same',
                  activation='relu')(input_layer)
    # Add 16 residual blocks
    res = residual_block(gen1)
    for i in range(residual_blocks - 1):
        res = residual_block(res)
    # Add the post-residual block
    gen2 = Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(res)
    gen2 = BatchNormalization(momentum=momentum)(gen2)
    # Take the sum of the output from the pre-residual block (gen1) and
    # the post-residual block (gen2)
    gen3 = Add()([gen2, gen1])
    # Add an upsampling block
    gen4 = UpSampling2D(size=2)(gen3)
    gen4 = Conv2D(filters=256, kernel_size=3, strides=1, padding='same')(gen4)
    gen4 = Activation('relu')(gen4)
    # Add another upsampling block
    gen5 = UpSampling2D(size=2)(gen4)
    gen5 = Conv2D(filters=256, kernel_size=3, strides=1,
                  padding='same')(gen5)
    gen5 = Activation('relu')(gen5)
    # Output convolution layer
    gen6 = Conv2D(filters=3, kernel_size=9, strides=1, padding='same')(gen5)
    output = Activation('tanh')(gen6)
    # Keras model
    model = Model(inputs=[input_layer], outputs=[output],
                  name='generator')
    return model
We have successfully created a Keras model for the generator network. In the next section, we will create a Keras model for the discriminator network.
The discriminator network
Let’s start by writing the layers for the discriminator network in the Keras framework and then create a Keras model, using the functional API of the Keras framework.
Perform the following steps to implement the discriminator network in Keras:
- Start by defining the hyperparameters required for the discriminator network:
leakyrelu_alpha = 0.2
momentum = 0.8
input_shape = (256, 256, 3)
- Next, create an input layer to feed input to the network, as follows:
input_layer = Input(shape=input_shape)
- Next, add a convolution block, as follows:
Configuration:
- Filters: 64
- Kernel size: 3
- Strides: 1
- Padding: same
- Activation: LeakyReLU with alpha equal to 0.2:
dis1 = Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(input_layer)
dis1 = LeakyReLU(alpha=leakyrelu_alpha)(dis1)
- Next, add another seven convolution blocks, as follows:
Configuration:
- Filters: 64, 128, 128, 256, 256, 512, 512
- Kernel size: 3, 3, 3, 3, 3, 3, 3
- Strides: 2, 1, 2, 1, 2, 1, 2
- Padding: same for each convolution layer
- Activation: LeakyReLU with alpha equal to 0.2 for each convolution layer
- Batch normalization: Yes (momentum=0.8) after each activation:
# Add the second convolution block
dis2 = Conv2D(filters=64, kernel_size=3, strides=2, padding='same')(dis1)
dis2 = LeakyReLU(alpha=leakyrelu_alpha)(dis2)
dis2 = BatchNormalization(momentum=momentum)(dis2)
# Add the third convolution block
dis3 = Conv2D(filters=128, kernel_size=3, strides=1, padding='same')(dis2)
dis3 = LeakyReLU(alpha=leakyrelu_alpha)(dis3)
dis3 = BatchNormalization(momentum=momentum)(dis3)
# Add the fourth convolution block
dis4 = Conv2D(filters=128, kernel_size=3, strides=2, padding='same')(dis3)
dis4 = LeakyReLU(alpha=leakyrelu_alpha)(dis4)
dis4 = BatchNormalization(momentum=0.8)(dis4)
# Add the fifth convolution block
dis5 = Conv2D(256, kernel_size=3, strides=1, padding='same')(dis4)
dis5 = LeakyReLU(alpha=leakyrelu_alpha)(dis5)
dis5 = BatchNormalization(momentum=momentum)(dis5)
# Add the sixth convolution block
dis6 = Conv2D(filters=256, kernel_size=3, strides=2, padding='same')(dis5)
dis6 = LeakyReLU(alpha=leakyrelu_alpha)(dis6)
dis6 = BatchNormalization(momentum=momentum)(dis6)
# Add the seventh convolution block
dis7 = Conv2D(filters=512, kernel_size=3, strides=1, padding='same')(dis6)
dis7 = LeakyReLU(alpha=leakyrelu_alpha)(dis7)
dis7 = BatchNormalization(momentum=momentum)(dis7)
# Add the eighth convolution block
dis8 = Conv2D(filters=512, kernel_size=3, strides=2, padding='same')(dis7)
dis8 = LeakyReLU(alpha=leakyrelu_alpha)(dis8)
dis8 = BatchNormalization(momentum=momentum)(dis8)
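To see how the 256 × 256 input shrinks through these blocks: with padding='same', a convolution with stride s maps a spatial size n to ceil(n / s). The helper below is only an illustration (its name is hypothetical) that traces the strides of dis1 through dis8:

```python
import math


def same_padding_sizes(size, strides):
    """Spatial size after each 'same'-padded convolution with the given strides."""
    sizes = []
    for stride in strides:
        size = math.ceil(size / stride)
        sizes.append(size)
    return sizes


# Strides of dis1 ... dis8
strides = [1, 2, 1, 2, 1, 2, 1, 2]
print(same_padding_sizes(256, strides))
# [256, 128, 128, 64, 64, 32, 32, 16]
```

So dis8 has the shape (16, 16, 512) for each image.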
- Next, add a dense layer with 1,024 nodes, as follows:
Configuration:
- Nodes: 1024
- Activation: LeakyReLU with alpha equal to 0.2:
dis9 = Dense(units=1024)(dis8)
dis9 = LeakyReLU(alpha=0.2)(dis9)
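- Then, add a dense layer with a sigmoid activation to return the probabilities. Keras applies a Dense layer to the last axis of its input and no Flatten layer is used here, so, because dis8 is a (16, 16, 512) feature map per image, the output is a (16, 16, 1) map of probabilities rather than a single value: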
output = Dense(units=1, activation='sigmoid')(dis9)
- Finally, create a Keras model and specify the inputs and the outputs for the network:
model = Model(inputs=[input_layer], outputs=[output],
name='discriminator')
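Since a Keras Dense layer acts on the last axis of its input, both dense layers keep the 16 × 16 grid, and any labels used to train this discriminator need the same (16, 16, 1) shape per image. The NumPy sketch below mimics the two dense layers on a stand-in tensor to show the resulting shapes; the weights are random placeholders, and the helper names are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(0)


def dense_last_axis(x, units):
    """Mimic a Keras Dense layer: a matrix product over the last axis plus a bias."""
    fan_in = x.shape[-1]
    # Random stand-in weights, scaled to keep the activations small
    kernel = rng.normal(size=(fan_in, units)) / np.sqrt(fan_in)
    bias = np.zeros(units)
    return x @ kernel + bias


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


dis8 = rng.normal(size=(1, 16, 16, 512))  # stand-in for the dis8 feature map
dis9 = dense_last_axis(dis8, 1024)        # LeakyReLU omitted for brevity
probs = sigmoid(dense_last_axis(dis9, 1))
print(dis9.shape)   # (1, 16, 16, 1024)
print(probs.shape)  # (1, 16, 16, 1)
```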
Wrap the entire code for the discriminator network inside a function, as follows:
def build_discriminator():
    """
    Create a discriminator network using the hyperparameter values defined below
    :return: the discriminator as a Keras Model
    """
    leakyrelu_alpha = 0.2
    momentum = 0.8
    input_shape = (256, 256, 3)
    input_layer = Input(shape=input_shape)
    # Add the first convolution block
    dis1 = Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(input_layer)
    dis1 = LeakyReLU(alpha=leakyrelu_alpha)(dis1)
    # Add the second convolution block
    dis2 = Conv2D(filters=64, kernel_size=3, strides=2, padding='same')(dis1)
    dis2 = LeakyReLU(alpha=leakyrelu_alpha)(dis2)
    dis2 = BatchNormalization(momentum=momentum)(dis2)
    # Add the third convolution block
    dis3 = Conv2D(filters=128, kernel_size=3, strides=1, padding='same')(dis2)
    dis3 = LeakyReLU(alpha=leakyrelu_alpha)(dis3)
    dis3 = BatchNormalization(momentum=momentum)(dis3)
    # Add the fourth convolution block
    dis4 = Conv2D(filters=128, kernel_size=3, strides=2, padding='same')(dis3)
    dis4 = LeakyReLU(alpha=leakyrelu_alpha)(dis4)
    dis4 = BatchNormalization(momentum=momentum)(dis4)
    # Add the fifth convolution block
    dis5 = Conv2D(filters=256, kernel_size=3, strides=1, padding='same')(dis4)
    dis5 = LeakyReLU(alpha=leakyrelu_alpha)(dis5)
    dis5 = BatchNormalization(momentum=momentum)(dis5)
    # Add the sixth convolution block
    dis6 = Conv2D(filters=256, kernel_size=3, strides=2, padding='same')(dis5)
    dis6 = LeakyReLU(alpha=leakyrelu_alpha)(dis6)
    dis6 = BatchNormalization(momentum=momentum)(dis6)
    # Add the seventh convolution block
    dis7 = Conv2D(filters=512, kernel_size=3, strides=1, padding='same')(dis6)
    dis7 = LeakyReLU(alpha=leakyrelu_alpha)(dis7)
    dis7 = BatchNormalization(momentum=momentum)(dis7)
    # Add the eighth convolution block
    dis8 = Conv2D(filters=512, kernel_size=3, strides=2, padding='same')(dis7)
    dis8 = LeakyReLU(alpha=leakyrelu_alpha)(dis8)
    dis8 = BatchNormalization(momentum=momentum)(dis8)
    # Add a dense layer
    dis9 = Dense(units=1024)(dis8)
    dis9 = LeakyReLU(alpha=leakyrelu_alpha)(dis9)
    # Last dense layer, for classification (a 16 x 16 x 1 probability map)
    output = Dense(units=1, activation='sigmoid')(dis9)
    model = Model(inputs=[input_layer], outputs=[output], name='discriminator')
    return model
In this section, we have successfully created a Keras model for the discriminator network.
The adversarial network
The adversarial network is a combined network that uses the generator, the discriminator, and VGG19. In this section, we will create an adversarial network.
Perform the following steps to create an adversarial network:
- Start by creating an input layer for the network:
input_low_resolution = Input(shape=(64, 64, 3))
The adversarial network receives low-resolution images of shape (64, 64, 3), so the input layer has that shape.
- Next, generate fake high-resolution images using the generator network, as follows:
fake_hr_images = generator(input_low_resolution)
- Next, extract the features of the fake images using the VGG19 network, as follows:
fake_features = vgg(fake_hr_images)
- Next, make the discriminator network non-trainable in the adversarial network:
discriminator.trainable = False
We are making the discriminator network non-trainable because we don’t want to train the discriminator network while we train the generator network.
- Next, pass the fake images to the discriminator network:
output = discriminator(fake_hr_images)
- Finally, create a Keras model, which will be our adversarial model:
model = Model(inputs=[input_low_resolution], outputs=[output,
fake_features])
- Wrap the entire code for the adversarial model inside a Python function:
def build_adversarial_model(generator, discriminator, vgg):
    input_low_resolution = Input(shape=(64, 64, 3))
    fake_hr_images = generator(input_low_resolution)
    fake_features = vgg(fake_hr_images)
    discriminator.trainable = False
    output = discriminator(fake_hr_images)
    model = Model(inputs=[input_low_resolution],
                  outputs=[output, fake_features])
    for layer in model.layers:
        print(layer.name, layer.trainable)
    # summary() prints the table itself and returns None
    model.summary()
    return model
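Setting discriminator.trainable = False removes the discriminator's weights from the trainable weights of the combined model, so training the adversarial model updates only the generator. The sketch below checks this with tf.keras and tiny stand-in networks; the make_* helpers, layer sizes, and the 8 × 8 to 32 × 32 image sizes are hypothetical, and only the wiring mirrors build_adversarial_model. The VGG stand-in is frozen as well, as VGG19 is in the training section:

```python
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv2D, Dense, GlobalAveragePooling2D, UpSampling2D


def make_generator():
    inp = Input(shape=(8, 8, 3))
    x = UpSampling2D(size=4)(inp)  # 8x8 -> 32x32, no weights
    out = Conv2D(3, 3, padding="same", activation="tanh")(x)
    return Model(inp, out, name="tiny_generator")


def make_discriminator():
    inp = Input(shape=(32, 32, 3))
    x = Conv2D(4, 3, strides=2, padding="same")(inp)
    x = GlobalAveragePooling2D()(x)
    return Model(inp, Dense(1, activation="sigmoid")(x), name="tiny_discriminator")


def make_feature_extractor():
    inp = Input(shape=(32, 32, 3))
    return Model(inp, Conv2D(8, 3, padding="same")(inp), name="tiny_vgg")


generator = make_generator()
discriminator = make_discriminator()
vgg = make_feature_extractor()
vgg.trainable = False

# Same wiring as build_adversarial_model
input_low_resolution = Input(shape=(8, 8, 3))
fake_hr_images = generator(input_low_resolution)
fake_features = vgg(fake_hr_images)
discriminator.trainable = False
output = discriminator(fake_hr_images)
adversarial = Model(input_low_resolution, [output, fake_features])

# Only the generator's Conv2D kernel and bias remain trainable
print(len(adversarial.trainable_weights), len(generator.trainable_weights))  # 2 2
```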
We have now successfully implemented the networks in Keras. Next, we train the network on the dataset that we downloaded.
Training the SRGAN
Training the SRGAN network is a two-step process. In the first step, we train the discriminator network. In the second step, we train the adversarial network, which eventually trains the generator network. Let’s start training the network.
Perform the following steps to train the SRGAN network:
- Start by defining the hyperparameters required for the training:
# Define hyperparameters
data_dir = "Path/to/the/dataset/img_align_celeba/*.*"
# Number of training iterations (one random batch is sampled per iteration)
epochs = 20000
batch_size = 1
# Shape of low-resolution and high-resolution images
low_resolution_shape = (64, 64, 3)
high_resolution_shape = (256, 256, 3)
- Next, define the training optimizer. For all networks, we will use the Adam optimizer with a learning rate of 0.0002 and beta_1 equal to 0.5:
# Common optimizer for all networks
common_optimizer = Adam(0.0002, 0.5)
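For reference, here is the textbook Adam update (Kingma and Ba) with the learning rate and beta_1 chosen above. beta_2 = 0.999 and epsilon = 1e-7 are assumptions chosen to match Keras's defaults, and Keras's own implementation differs slightly in where epsilon enters. A NumPy sketch of a single step:

```python
import numpy as np


def adam_step(theta, grad, m, v, t, lr=0.0002, beta_1=0.5, beta_2=0.999, eps=1e-7):
    """One Adam update; t is the 1-based step count."""
    m = beta_1 * m + (1 - beta_1) * grad       # running mean of gradients
    v = beta_2 * v + (1 - beta_2) * grad ** 2  # running mean of squared gradients
    m_hat = m / (1 - beta_1 ** t)              # bias correction
    v_hat = v / (1 - beta_2 ** t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v


theta, m, v = np.array([1.0]), np.zeros(1), np.zeros(1)
theta, m, v = adam_step(theta, np.array([0.3]), m, v, t=1)
print(theta)  # about [0.9998]: the first step moves the weight by roughly lr
```

A beta_1 of 0.5 instead of the default 0.9 shortens the memory of the first-moment estimate, a common setting for GAN training.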
Building and compiling the networks
In this section, we will go through the different steps required to build and compile the networks:
- Build and compile the VGG19 network:
vgg = build_vgg()
vgg.trainable = False
vgg.compile(loss='mse', optimizer=common_optimizer, metrics=
['accuracy'])
To compile VGG19, use mse as the loss, accuracy as the metric, and common_optimizer as the optimizer. Setting vgg.trainable = False before compiling disables training, because VGG19 serves only as a fixed feature extractor and we don’t want to train it.
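build_vgg() is called above but is not defined in this excerpt. The following is a minimal sketch using tf.keras. It assumes the activations of the block5_conv4 layer (the "VGG54" features described in the SRGAN paper) and a 256 × 256 × 3 input to match the generator's output. The weights parameter is an addition that lets you pass weights=None to skip downloading the ImageNet weights, for example in tests:

```python
from tensorflow.keras import Model
from tensorflow.keras.applications import VGG19


def build_vgg(weights="imagenet", input_shape=(256, 256, 3)):
    """Frozen VGG19 feature extractor returning block5_conv4 activations."""
    # include_top=False drops the fully connected layers, so inputs other than
    # 224x224 are accepted (height and width must be at least 32)
    base = VGG19(weights=weights, include_top=False, input_shape=input_shape)
    features = base.get_layer("block5_conv4").output
    model = Model(inputs=base.input, outputs=features, name="vgg19_features")
    model.trainable = False
    return model


vgg_features = build_vgg(weights=None)
print(tuple(vgg_features.output.shape))  # (None, 16, 16, 512)
```

The ImageNet weights were trained on inputs prepared with tf.keras.applications.vgg19.preprocess_input (0 to 255 values, BGR order, mean subtraction), whereas this tutorial feeds images in [-1, 1] directly; you may want to add that conversion in front of the feature extractor.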
- Next, build and compile the discriminator network, as follows:
discriminator = build_discriminator()
discriminator.compile(loss='mse', optimizer=common_optimizer,
metrics=['accuracy'])
To compile the discriminator network, use mse as the loss, accuracy as the metrics, and common_optimizer as the optimizer.
- Next, build the generator network:
generator = build_generator()
- Next, create an adversarial model. Start by creating two input layers:
input_high_resolution = Input(shape=high_resolution_shape)
input_low_resolution = Input(shape=low_resolution_shape)
- Next, use the generator network to symbolically generate high-resolution images from the low-resolution images:
generated_high_resolution_images = generator(input_low_resolution)
- Use VGG19 to extract feature maps for the generated images:
features = vgg(generated_high_resolution_images)
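- Make the discriminator network non-trainable, because we don’t want to train the discriminator model during the training of the adversarial model. Keras takes the trainable flag into account when a model is compiled, so the discriminator, which was compiled above while still trainable, can still be trained on its own, while the adversarial model compiled below treats its weights as frozen: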
discriminator.trainable = False
- Next, use the discriminator network to get the probabilities of the generated high-resolution fake images:
probs = discriminator(generated_high_resolution_images)
Here, probs represents the discriminator’s estimated probability that the generated images are real.
- Finally, create and compile the adversarial network:
adversarial_model = Model([input_low_resolution, input_high_resolution], [probs, features])
adversarial_model.compile(loss=['binary_crossentropy', 'mse'],
loss_weights=[1e-3, 1], optimizer=common_optimizer)
To compile the adversarial model, use binary_crossentropy and mse as the loss functions, common_optimizer as the optimizer, and [0.001, 1] as the loss weights.
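With loss_weights=[1e-3, 1], the adversarial model minimizes 0.001 × (binary cross-entropy on the discriminator's output) + 1 × (mean squared error between VGG19 feature maps). In the SRGAN paper these correspond to the adversarial loss, computed against "real" labels, and the VGG content loss, computed against the features of the real high-resolution image. The NumPy sketch below evaluates that weighted sum on made-up values; it is a simplified stand-in for Keras's loss functions, not a copy of them:

```python
import numpy as np


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; predictions are clipped to avoid log(0)."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))


def mse(a, b):
    """Mean squared error."""
    return np.mean((a - b) ** 2)


# Made-up values: the discriminator is undecided (0.5) about a generated image,
# and the generated image's features are off by 0.1 everywhere
probs = np.full((1, 16, 16, 1), 0.5)
real_labels = np.ones_like(probs)  # the generator wants "real" verdicts
fake_features = np.full((1, 4, 4, 8), 0.6)
real_features = np.full((1, 4, 4, 8), 0.5)

loss_weights = [1e-3, 1]
total = (loss_weights[0] * binary_crossentropy(real_labels, probs)
         + loss_weights[1] * mse(fake_features, real_features))
print(round(float(total), 6))  # 0.010693
```

The small 0.001 weight keeps the feature (content) term dominant in the total loss.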
- Add TensorBoard to visualize the training losses and the network graphs:
tensorboard = TensorBoard(log_dir="logs/{}".format(time.time()))
tensorboard.set_model(generator)
tensorboard.set_model(discriminator)
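In a custom training loop like this one, fit() never calls the callback, so the loss values have to be written explicitly. In TensorFlow 2 this can be done with tf.summary; a minimal sketch, where the write_losses helper, the tag names, and the log directory layout are hypothetical:

```python
import os
import time

import tensorflow as tf


def write_losses(writer, losses, step):
    """Write a dict of scalar losses (name -> value) at the given step."""
    with writer.as_default():
        for name, value in losses.items():
            tf.summary.scalar(name, value, step=step)
    writer.flush()


log_dir = os.path.join("logs", str(time.time()))
writer = tf.summary.create_file_writer(log_dir)
write_losses(writer, {"d_loss": 0.7, "g_loss": 1.2}, step=0)
```

Calling write_losses once per iteration and running tensorboard --logdir logs then shows the loss curves.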
- Create a loop that should run for the specified number of epochs:
for epoch in range(epochs):
    print("Epoch:{}".format(epoch))
After this step, all of the code will be inside this for loop.
- Next, sample a batch of high-resolution and low-resolution images, as follows:
high_resolution_images, low_resolution_images = sample_images(
    data_dir=data_dir, batch_size=batch_size,
    low_resolution_shape=low_resolution_shape,
    high_resolution_shape=high_resolution_shape)
The code for the sample_images function follows. It picks a random batch of image paths, loads each image, resizes it to both the high-resolution and the low-resolution shape, and randomly flips both versions horizontally:
def sample_images(data_dir, batch_size, high_resolution_shape, low_resolution_shape):
    # Make a list of all images inside the data directory
    all_images = glob.glob(data_dir)
    # Choose a random batch of images
    images_batch = np.random.choice(all_images, size=batch_size)
    low_resolution_images = []
    high_resolution_images = []
    for img in images_batch:
        # Get an ndarray of the current image (imread/imresize come from scipy.misc)
        img1 = imread(img, mode='RGB')
        img1 = img1.astype(np.float32)
        # Resize the image
        img1_high_resolution = imresize(img1, high_resolution_shape)
        img1_low_resolution = imresize(img1, low_resolution_shape)
        # Do a random flip
        if np.random.random() < 0.5:
            img1_high_resolution = np.fliplr(img1_high_resolution)
            img1_low_resolution = np.fliplr(img1_low_resolution)
        high_resolution_images.append(img1_high_resolution)
        low_resolution_images.append(img1_low_resolution)
    # Return both batches as a tuple of NumPy arrays
    return np.array(high_resolution_images), np.array(low_resolution_images)
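scipy.misc.imread and scipy.misc.imresize are not available in current SciPy releases. A possible replacement using Pillow and NumPy is sketched below; the function name sample_images_pil is hypothetical, and bilinear resizing is an assumption. Otherwise it keeps the same behavior: a random batch, resizing to both shapes, and a random horizontal flip applied to both versions:

```python
import glob

import numpy as np
from PIL import Image


def sample_images_pil(data_dir, batch_size, high_resolution_shape, low_resolution_shape):
    """Pillow-based variant of sample_images (hypothetical replacement)."""
    all_images = glob.glob(data_dir)
    images_batch = np.random.choice(all_images, size=batch_size)
    high_resolution_images, low_resolution_images = [], []
    for path in images_batch:
        img = Image.open(path).convert("RGB")
        # PIL sizes are (width, height); the shapes here are (height, width, channels)
        hr = img.resize((high_resolution_shape[1], high_resolution_shape[0]), Image.BILINEAR)
        lr = img.resize((low_resolution_shape[1], low_resolution_shape[0]), Image.BILINEAR)
        hr = np.asarray(hr, dtype=np.float32)
        lr = np.asarray(lr, dtype=np.float32)
        if np.random.random() < 0.5:  # random horizontal flip of both versions
            hr, lr = np.fliplr(hr), np.fliplr(lr)
        high_resolution_images.append(hr)
        low_resolution_images.append(lr)
    return np.array(high_resolution_images), np.array(low_resolution_images)
```

It takes the same arguments as sample_images, so it can be called in its place inside the training loop.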
- Next, normalize the images to convert the pixel values to the range [-1, 1], as follows:
high_resolution_images = high_resolution_images / 127.5 - 1.
low_resolution_images = low_resolution_images / 127.5 - 1.
Converting the pixel values to the range [-1, 1] is important: the generator network ends with a tanh activation, which squashes its outputs into that same range, so the real images must be scaled the same way for the discriminator and the losses to compare like with like. After this step, the loop goes on to train the discriminator network and the generator network, and then to visualize the generated images and evaluate the model.
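The scaling is a simple affine map, and its inverse, (x + 1) × 127.5, turns generator outputs back into 0 to 255 pixel values, for example before saving generated images. A small NumPy sketch (the helper names are hypothetical):

```python
import numpy as np


def normalize(images):
    """Map pixel values from [0, 255] to [-1, 1], matching the generator's tanh."""
    return images / 127.5 - 1.0


def denormalize(images):
    """Map values from [-1, 1] back to uint8 pixels in [0, 255]."""
    return np.clip((images + 1.0) * 127.5, 0, 255).round().astype(np.uint8)


pixels = np.array([0, 127.5, 255], dtype=np.float32)
print(normalize(pixels))  # [-1.  0.  1.]
print(denormalize(normalize(np.array([0, 128, 255], dtype=np.float32))))  # [  0 128 255]
```

Clipping in denormalize guards against tiny floating-point overshoots outside [-1, 1] before the cast to uint8.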
In this tutorial, we learned how to download the CelebA dataset and implemented the SRGAN networks in Keras before starting to train them. If you want to learn more about how to evaluate and optimize the trained SRGAN network, be sure to check out the book Generative Adversarial Networks Projects.