Super-Resolution Generative Adversarial Network, or SRGAN, is a Generative Adversarial Network (GAN) that can generate super-resolution images from low-resolution images, with finer details and higher quality. Earlier, CNNs were used to produce high-resolution images; such networks train more quickly and achieve high accuracy, but in some cases they are incapable of recovering finer details and often generate blurry images.

In this tutorial, we will learn how to implement an SRGAN network in the Keras framework that is capable of generating high-resolution images. SRGANs were introduced in the paper titled Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network, by Christian Ledig, Lucas Theis, Ferenc Huszar, Jose Caballero, Andrew Cunningham, and others.

This tutorial is an excerpt taken from the book Generative Adversarial Networks Projects written by Kailash Ahirwar. The book explores unsupervised techniques for training neural networks and includes seven end-to-end projects in the GAN domain.

Downloading the CelebA dataset

For this tutorial, we will use the large-scale CelebFaces Attributes (CelebA) dataset, which is available at the link below. The dataset contains 202,599 face images of celebrities.

The dataset is available for non-commercial research purposes only. If you intend to use the dataset for commercial purposes, seek permission from the owners of the images.

We will use the CelebA dataset to train our SRGAN network. Perform the following steps to download and extract the dataset:

  1. Download the dataset from the following link:
https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AAB06FXaQRUNtjW9ntaoPGvCa?dl=0
  2. Extract images from the downloaded img_align_celeba.zip by executing the following command:
unzip img_align_celeba.zip

We have now downloaded and extracted the dataset. We can now start working on the Keras implementation of SRGAN.

The Keras implementation of SRGAN

SRGAN has three neural networks: a generator, a discriminator, and a VGG19 network pre-trained on the ImageNet dataset. In this section, we will write the implementation for all the networks. Let’s start by implementing the generator network.

Before starting to write the implementations, create a Python file called main.py and import the essential modules, as follows:

import glob
import os
import time

import numpy as np
import tensorflow as tf
from keras import Input
from keras.applications import VGG19
from keras.callbacks import TensorBoard
from keras.layers import BatchNormalization, Activation, LeakyReLU, Add, Dense, PReLU, Flatten
from keras.layers.convolutional import Conv2D, UpSampling2D
from keras.models import Model
from keras.optimizers import Adam
from keras_preprocessing.image import img_to_array, load_img
from scipy.misc import imread, imresize, imsave

The generator network

Let’s start by writing the layers for the generator network in the Keras framework and then create a Keras model, using the functional API of the Keras framework.

Perform the following steps to implement the generator network in Keras:

  1. Start by defining the hyperparameters required for the generator network:
residual_blocks = 16
momentum = 0.8
input_shape = (64, 64, 3)
  2. Next, create an input layer to feed input to the network, as follows:
input_layer = Input(shape=input_shape)
The input layer takes an input image of shape (64, 64, 3) and passes it to the next layer in the network.
  3. Next, add the pre-residual block (2D convolution layer), as follows:
    Configuration:

    • Filters: 64
    • Kernel size: 9
    • Strides: 1
    • Padding: same
    • Activation: relu:
gen1 = Conv2D(filters=64, kernel_size=9, strides=1, padding='same', activation='relu')(input_layer)
  4. Next, write a method with the entire code for the residual block, as shown here:
def residual_block(x):
    """
    Residual block
    """
    filters = [64, 64]
    kernel_size = 3
    strides = 1
    padding = "same"
    momentum = 0.8
    activation = "relu"

    res = Conv2D(filters=filters[0], kernel_size=kernel_size, 
                 strides=strides, padding=padding)(x)
    res = Activation(activation=activation)(res)
    res = BatchNormalization(momentum=momentum)(res)

    res = Conv2D(filters=filters[1], kernel_size=kernel_size, 
                 strides=strides, padding=padding)(res)
    res = BatchNormalization(momentum=momentum)(res)

    # Add res and x
    res = Add()([res, x])
    return res
  5. Now, add 16 residual blocks using the residual_block function, defined in the last step:
res = residual_block(gen1)
for i in range(residual_blocks - 1):
    res = residual_block(res)

The output of the pre-residual block goes to the first residual block. The output of the first residual block goes to the second residual block, and so on, up to the 16th residual block.

  6. Next, add the post-residual block (a 2D convolution layer followed by a batch normalization layer), as follows:
    Configuration:

    • Filters: 64
    • Kernel size: 3
    • Strides: 1
    • Padding: same
    • Batchnormalization: Yes (momentum=0.8):
gen2 = Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(res)
gen2 = BatchNormalization(momentum=momentum)(gen2)
  7. Now, add an Add layer to take the sum of the output from the pre-residual block, which is gen1, and the output from the post-residual block, which is gen2. This layer generates another tensor of the same shape.
gen3 = Add()([gen2, gen1])
  8. Next, add an upsampling block, as follows:
    Configuration:

    • Upsampling size: 2
    • Filters: 256
    • Kernel size: 3
    • Strides: 1
    • Padding: same
    • Activation: relu (the original SRGAN paper uses PReLU here):
gen4 = UpSampling2D(size=2)(gen3)
gen4 = Conv2D(filters=256, kernel_size=3, strides=1, padding='same')(gen4)
gen4 = Activation('relu')(gen4)
  9. Next, add another upsampling block, as follows:
    Configuration:

    • Upsampling size: 2
    • Filters: 256
    • Kernel size: 3
    • Strides: 1
    • Padding: same
    • Activation: relu (again, the paper uses PReLU):
gen5 = UpSampling2D(size=2)(gen4)
gen5 = Conv2D(filters=256, kernel_size=3, strides=1, padding='same')(gen5)
gen5 = Activation('relu')(gen5)
  10. Finally, add the output convolution layer:
    Configuration:

    • Filters: 3 (equal to number of channels)
    • Kernel size: 9
    • Strides: 1
    • Padding: same
    • Activation: tanh:
gen6 = Conv2D(filters=3, kernel_size=9, strides=1, padding='same')(gen5)
output = Activation('tanh')(gen6)

Once you have defined all the layers in the network, you can create a Keras model. We have defined the computation graph using Keras’s functional API; let’s now create a Keras model by specifying the input and output for the network.

  11. Now, create a Keras model and specify the inputs and the outputs for the model, as follows:
model = Model(inputs=[input_layer], outputs=[output], name='generator')

We have successfully created a Keras model for the generator network. Now wrap the entire code for the generator network inside a Python function, as follows:

def build_generator():
    """
    Create a generator network using the hyperparameter values defined below
    :return:
    """
    residual_blocks = 16
    momentum = 0.8
    input_shape = (64, 64, 3)

    # Input Layer of the generator network
    input_layer = Input(shape=input_shape)

    # Add the pre-residual block
    gen1 = Conv2D(filters=64, kernel_size=9, strides=1, padding='same',  
                  activation='relu')(input_layer)

    # Add 16 residual blocks
    res = residual_block(gen1)
    for i in range(residual_blocks - 1):
        res = residual_block(res)

    # Add the post-residual block
    gen2 = Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(res)
    gen2 = BatchNormalization(momentum=momentum)(gen2)

    # Take the sum of the output from the pre-residual block (gen1)
    # and the post-residual block (gen2)
    gen3 = Add()([gen2, gen1])

    # Add an upsampling block
    gen4 = UpSampling2D(size=2)(gen3)
    gen4 = Conv2D(filters=256, kernel_size=3, strides=1, padding='same')(gen4)
    gen4 = Activation('relu')(gen4)

    # Add another upsampling block
    gen5 = UpSampling2D(size=2)(gen4)
    gen5 = Conv2D(filters=256, kernel_size=3, strides=1, 
                  padding='same')(gen5)
    gen5 = Activation('relu')(gen5)

    # Output convolution layer
    gen6 = Conv2D(filters=3, kernel_size=9, strides=1, padding='same')(gen5)
    output = Activation('tanh')(gen6)

    # Keras model
    model = Model(inputs=[input_layer], outputs=[output], 
                  name='generator')
    return model
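
As a quick sanity check, you can build the generator and inspect its shapes; with a (64, 64, 3) input and two 2x upsampling blocks, the final output should be four times larger in each spatial dimension:

generator = build_generator()
generator.summary()  # the last layer's output shape should be (None, 256, 256, 3)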

We have successfully created a Keras model for the generator network. In the next section, we will create a Keras model for the discriminator network.

The discriminator network

Let’s start by writing the layers for the discriminator network in the Keras framework and then create a Keras model, using the functional API of the Keras framework.

Perform the following steps to implement the discriminator network in Keras:

  1. Start by defining the hyperparameters required for the discriminator network:
leakyrelu_alpha = 0.2
momentum = 0.8
input_shape = (256, 256, 3)
  2. Next, create an input layer to feed input to the network, as follows:
input_layer = Input(shape=input_shape)
  3. Next, add a convolution block, as follows:
    Configuration:

    • Filters: 64
    • Kernel size: 3
    • Strides: 1
    • Padding: same
    • Activation: LeakyReLU with alpha equal to 0.2:
dis1 = Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(input_layer)
dis1 = LeakyReLU(alpha=leakyrelu_alpha)(dis1)
  4. Next, add another seven convolution blocks, as follows:
    Configuration:

    • Filters: 64, 128, 128, 256, 256, 512, 512
    • Kernel size: 3, 3, 3, 3, 3, 3, 3
    • Strides: 2, 1, 2, 1, 2, 1, 2
    • Padding: same for each convolution layer
    • Activation: LeakyReLU with alpha equal to 0.2 for each convolution layer:
# Add the 2nd convolution block
dis2 = Conv2D(filters=64, kernel_size=3, strides=2, padding='same')(dis1)
dis2 = LeakyReLU(alpha=leakyrelu_alpha)(dis2)
dis2 = BatchNormalization(momentum=momentum)(dis2)

# Add the third convolution block
dis3 = Conv2D(filters=128, kernel_size=3, strides=1, padding='same')(dis2)
dis3 = LeakyReLU(alpha=leakyrelu_alpha)(dis3)
dis3 = BatchNormalization(momentum=momentum)(dis3)

# Add the fourth convolution block
dis4 = Conv2D(filters=128, kernel_size=3, strides=2, padding='same')(dis3)
dis4 = LeakyReLU(alpha=leakyrelu_alpha)(dis4)
dis4 = BatchNormalization(momentum=momentum)(dis4)

# Add the fifth convolution block
dis5 = Conv2D(filters=256, kernel_size=3, strides=1, padding='same')(dis4)
dis5 = LeakyReLU(alpha=leakyrelu_alpha)(dis5)
dis5 = BatchNormalization(momentum=momentum)(dis5)

# Add the sixth convolution block
dis6 = Conv2D(filters=256, kernel_size=3, strides=2, padding='same')(dis5)
dis6 = LeakyReLU(alpha=leakyrelu_alpha)(dis6)
dis6 = BatchNormalization(momentum=momentum)(dis6)

# Add the seventh convolution block
dis7 = Conv2D(filters=512, kernel_size=3, strides=1, padding='same')(dis6)
dis7 = LeakyReLU(alpha=leakyrelu_alpha)(dis7)
dis7 = BatchNormalization(momentum=momentum)(dis7)

# Add the eighth convolution block
dis8 = Conv2D(filters=512, kernel_size=3, strides=2, padding='same')(dis7)
dis8 = LeakyReLU(alpha=leakyrelu_alpha)(dis8)
dis8 = BatchNormalization(momentum=momentum)(dis8)
  5. Next, add a dense layer with 1,024 nodes, as follows:
    Configuration:

    • Nodes: 1024
    • Activation: LeakyReLU with alpha equal to 0.2:
dis9 = Dense(units=1024)(dis8)
dis9 = LeakyReLU(alpha=0.2)(dis9)
  6. Then, add a dense layer to return the probabilities, as follows:
output = Dense(units=1, activation='sigmoid')(dis9)
  7. Finally, create a Keras model and specify the inputs and the outputs for the network:
model = Model(inputs=[input_layer], outputs=[output], 
              name='discriminator')

Wrap the entire code for the discriminator network inside a function, as follows:

def build_discriminator():
    """
    Create a discriminator network using the hyperparameter values defined below
    :return:
    """
    leakyrelu_alpha = 0.2
    momentum = 0.8
    input_shape = (256, 256, 3)

    input_layer = Input(shape=input_shape)

    # Add the first convolution block
    dis1 = Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(input_layer)
    dis1 = LeakyReLU(alpha=leakyrelu_alpha)(dis1)

    # Add the 2nd convolution block
    dis2 = Conv2D(filters=64, kernel_size=3, strides=2, padding='same')(dis1)
    dis2 = LeakyReLU(alpha=leakyrelu_alpha)(dis2)
    dis2 = BatchNormalization(momentum=momentum)(dis2)

    # Add the third convolution block
    dis3 = Conv2D(filters=128, kernel_size=3, strides=1, padding='same')(dis2)
    dis3 = LeakyReLU(alpha=leakyrelu_alpha)(dis3)
    dis3 = BatchNormalization(momentum=momentum)(dis3)

    # Add the fourth convolution block
    dis4 = Conv2D(filters=128, kernel_size=3, strides=2, padding='same')(dis3)
    dis4 = LeakyReLU(alpha=leakyrelu_alpha)(dis4)
    dis4 = BatchNormalization(momentum=momentum)(dis4)

    # Add the fifth convolution block
    dis5 = Conv2D(filters=256, kernel_size=3, strides=1, padding='same')(dis4)
    dis5 = LeakyReLU(alpha=leakyrelu_alpha)(dis5)
    dis5 = BatchNormalization(momentum=momentum)(dis5)

    # Add the sixth convolution block
    dis6 = Conv2D(filters=256, kernel_size=3, strides=2, padding='same')(dis5)
    dis6 = LeakyReLU(alpha=leakyrelu_alpha)(dis6)
    dis6 = BatchNormalization(momentum=momentum)(dis6)

    # Add the seventh convolution block
    dis7 = Conv2D(filters=512, kernel_size=3, strides=1, padding='same')(dis6)
    dis7 = LeakyReLU(alpha=leakyrelu_alpha)(dis7)
    dis7 = BatchNormalization(momentum=momentum)(dis7)

    # Add the eighth convolution block
    dis8 = Conv2D(filters=512, kernel_size=3, strides=2, padding='same')(dis7)
    dis8 = LeakyReLU(alpha=leakyrelu_alpha)(dis8)
    dis8 = BatchNormalization(momentum=momentum)(dis8)

    # Add a dense layer
    dis9 = Dense(units=1024)(dis8)
    dis9 = LeakyReLU(alpha=0.2)(dis9)

    # Last dense layer - for classification
    output = Dense(units=1, activation='sigmoid')(dis9)

    model = Model(inputs=[input_layer], outputs=[output], name='discriminator')
    return model
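
Note that because no Flatten layer precedes the dense layers, Keras applies them along the last axis, so this discriminator outputs a (16, 16, 1) grid of probabilities rather than a single scalar. A quick check:

discriminator = build_discriminator()
discriminator.summary()  # final output shape should be (None, 16, 16, 1)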

In this section, we have successfully created a Keras model for the discriminator network.

The adversarial network

The adversarial network is a combined network that uses the generator, the discriminator, and VGG19. In this section, we will create an adversarial network.
Perform the following steps to create an adversarial network:

  1. Start by creating an input layer for the network:
input_low_resolution = Input(shape=(64, 64, 3))

The adversarial network will receive an image of shape (64, 64, 3), which is why we have created an input layer of that shape.

  2. Next, generate fake high-resolution images using the generator network, as follows:
fake_hr_images = generator(input_low_resolution)
  3. Next, extract the features of the fake images using the VGG19 network, as follows:
fake_features = vgg(fake_hr_images)
  4. Next, make the discriminator network non-trainable in the adversarial network:
discriminator.trainable = False

We are making the discriminator network non-trainable because we don’t want to update the discriminator’s weights while we train the generator network. (In Keras, the trainable flag takes effect at compile time, so the standalone discriminator compiled separately still trains normally.)

  5. Next, pass the fake images to the discriminator network:
output = discriminator(fake_hr_images)
  6. Finally, create a Keras model, which will be our adversarial model:
model = Model(inputs=[input_low_resolution], outputs=[output, 
              fake_features])
  7. Wrap the entire code for the adversarial model inside a Python function:
def build_adversarial_model(generator, discriminator, vgg):

    input_low_resolution = Input(shape=(64, 64, 3))

    fake_hr_images = generator(input_low_resolution)
    fake_features = vgg(fake_hr_images)

    discriminator.trainable = False

    output = discriminator(fake_hr_images)

    model = Model(inputs=[input_low_resolution],
                  outputs=[output, fake_features])

    for layer in model.layers:
        print(layer.name, layer.trainable)

    print(model.summary())
    return model

We have now successfully implemented the networks in Keras. Next, we train them on the dataset that we downloaded. (Note that the training section below builds the adversarial model inline rather than calling build_adversarial_model.)

Training the SRGAN

Training the SRGAN network is a two-step process. In the first step, we train the discriminator network. In the second step, we train the adversarial network, which eventually trains the generator network. Let’s start training the network.
Perform the following steps to train the SRGAN network:

  1. Start by defining the hyperparameters required for the training:
# Define hyperparameters
data_dir = "Paht/to/the/dataset/img_align_celeba/*.*"
epochs = 20000
batch_size = 1

# Shape of low-resolution and high-resolution images
low_resolution_shape = (64, 64, 3)
high_resolution_shape = (256, 256, 3)
  2. Next, define the training optimizer. For all networks, we will use the Adam optimizer with the learning rate equal to 0.0002 and beta_1 equal to 0.5:
# Common optimizer for all networks
common_optimizer = Adam(0.0002, 0.5)

Building and compiling the networks

In this section, we will go through the different steps required to build and compile the networks:

  1. Build and compile the VGG19 network:
vgg = build_vgg()
vgg.trainable = False
vgg.compile(loss='mse', optimizer=common_optimizer, metrics=['accuracy'])

To compile VGG19, use mse as the loss, accuracy as the metric, and common_optimizer as the optimizer. Before compiling the network, set trainable to False, as we don’t want to train the VGG19 network.
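
The build_vgg function isn’t defined in this excerpt. Here is a minimal sketch consistent with how it is used, assuming the perceptual features come from an intermediate VGG19 layer (the exact layer index is an assumption, not something this excerpt specifies):

def build_vgg():
    """
    Build a VGG19 network that returns feature maps for high-resolution images.
    NOTE: this is a sketch; layers[9] as the feature layer is an assumption.
    """
    input_shape = (256, 256, 3)

    # Load VGG19 pre-trained on ImageNet, without the classifier head
    vgg = VGG19(weights="imagenet", include_top=False, input_shape=input_shape)

    # Expose an intermediate layer's output as the model's output
    features = vgg.layers[9].output

    model = Model(inputs=vgg.input, outputs=features, name='vgg')
    return model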

  2. Next, build and compile the discriminator network, as follows:
discriminator = build_discriminator()
discriminator.compile(loss='mse', optimizer=common_optimizer, 
                      metrics=['accuracy'])

To compile the discriminator network, use mse as the loss, accuracy as the metric, and common_optimizer as the optimizer.

  3. Next, build the generator network:
generator = build_generator()
  4. Next, create an adversarial model. Start by creating two input layers:
input_high_resolution = Input(shape=high_resolution_shape)
input_low_resolution = Input(shape=low_resolution_shape)
  5. Next, use the generator network to symbolically generate high-resolution images from the low-resolution images:
generated_high_resolution_images = generator(input_low_resolution)

Use VGG19 to extract feature maps for the generated images:

features = vgg(generated_high_resolution_images)

Make the discriminator network non-trainable, because we don’t want to train the discriminator model during the training of the adversarial model:

discriminator.trainable = False
  6. Next, use the discriminator network to get the probabilities of the generated high-resolution fake images:
probs = discriminator(generated_high_resolution_images)

Here, probs represents the probabilities of the generated images belonging to the real dataset.

  7. Finally, create and compile the adversarial network:
adversarial_model = Model([input_low_resolution, input_high_resolution], [probs, features])
adversarial_model.compile(loss=['binary_crossentropy', 'mse'], 
            loss_weights=[1e-3, 1], optimizer=common_optimizer)

To compile the adversarial model, use binary_crossentropy and mse as the loss functions, common_optimizer as the optimizer, and [0.001, 1] as the loss weights. In other words, the generator’s total loss is 0.001 times the adversarial loss plus the content (feature) loss, so the perceptual term dominates training.

  8. Add TensorBoard to visualize the training losses and the network graphs:
tensorboard = TensorBoard(log_dir="logs/{}".format(time.time()))
tensorboard.set_model(generator)
tensorboard.set_model(discriminator)
  9. Create a loop that runs for the specified number of epochs:
for epoch in range(epochs):
    print("Epoch:{}".format(epoch))

After this step, all of the code will be inside this for loop.

  10. Next, sample a batch of high-resolution and low-resolution images, as follows:
high_resolution_images, low_resolution_images = sample_images(
    data_dir=data_dir,
    batch_size=batch_size,
    low_resolution_shape=low_resolution_shape,
    high_resolution_shape=high_resolution_shape)

The code for the sample_images function is as follows. It loads a random batch of images and resizes each one to generate both high-resolution and low-resolution versions:

def sample_images(data_dir, batch_size, high_resolution_shape, low_resolution_shape):
    # Make a list of all images inside the data directory
    all_images = glob.glob(data_dir)

    # Choose a random batch of images
    images_batch = np.random.choice(all_images, size=batch_size)

    low_resolution_images = []
    high_resolution_images = []

    for img in images_batch:
        # Get an ndarray of the current image
        img1 = imread(img, mode='RGB')
        img1 = img1.astype(np.float32)

        # Resize the image
        img1_high_resolution = imresize(img1, high_resolution_shape)
        img1_low_resolution = imresize(img1, low_resolution_shape)

        # Do a random flip
        if np.random.random() < 0.5:
            img1_high_resolution = np.fliplr(img1_high_resolution)
            img1_low_resolution = np.fliplr(img1_low_resolution)

        high_resolution_images.append(img1_high_resolution)
        low_resolution_images.append(img1_low_resolution)

    return np.array(high_resolution_images), \
           np.array(low_resolution_images)
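
One caveat: imread and imresize were removed from scipy.misc in SciPy 1.2.0. If you are on a newer SciPy, a drop-in sketch using imageio and scikit-image (these libraries are substitutions, not what the book uses) would be:

import imageio
from skimage.transform import resize

# Read the image as an RGB float array
img1 = imageio.imread(img, pilmode='RGB').astype(np.float32)

# resize takes the output shape; preserve_range keeps pixel values in 0-255
img1_high_resolution = resize(img1, high_resolution_shape[:2], preserve_range=True)
img1_low_resolution = resize(img1, low_resolution_shape[:2], preserve_range=True)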
  11. Next, normalize the images to convert the pixel values to the range [-1, 1], as follows:
high_resolution_images = high_resolution_images / 127.5 - 1.
low_resolution_images = low_resolution_images / 127.5 - 1.

It is very important to convert the pixel values to a range between -1 and 1 because our generator network has a tanh activation at the end, which squashes its output to that same range; while calculating the loss, all values must be in the same range. After this step is complete, we train the discriminator network and the generator network, and then visualize the generated images and evaluate the model.
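
To give a sense of what comes next, here is a minimal sketch of the discriminator training step inside the loop (the (16, 16, 1) label shape follows from the discriminator’s output shape noted earlier; the full training loop is covered in the book):

# Labels for the real and fake image batches
real_labels = np.ones((batch_size, 16, 16, 1))
fake_labels = np.zeros((batch_size, 16, 16, 1))

# Generate fake high-resolution images from the low-resolution batch
generated_high_resolution_images = generator.predict(low_resolution_images)

# Train the discriminator on real and fake batches separately
d_loss_real = discriminator.train_on_batch(high_resolution_images, real_labels)
d_loss_fake = discriminator.train_on_batch(generated_high_resolution_images, fake_labels)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)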

In this tutorial, we learned how to download the CelebA dataset and implemented the SRGAN networks in Keras before training them. If you want to learn more about evaluating the trained SRGAN network and optimizing the trained model, be sure to check out the book Generative Adversarial Networks Projects.

Read Next

Generative Adversarial Networks: Generate images using Keras GAN [Tutorial]

What you need to know about Generative Adversarial Networks

Generative Adversarial Networks (GANs): The next milestone In Deep Learning
