How to build a neural network to fill the missing part of a handwritten digit using GANs [Tutorial]


GANs are neural networks used in unsupervised learning that generate synthetic data resembling a given set of input data. GANs have two components: a generator and a discriminator. The generator produces new instances of an object, and the discriminator determines whether a given instance belongs to the actual dataset. A generative model learns how the data is generated, that is, the structure of the data, in order to categorize it. This allows the system to generate samples with similar statistical properties. A discriminative model learns the relationship between the data and the labels associated with it, and categorizes input data without knowing how the data was generated. GANs combine the ideas behind both kinds of model to get a better network architecture.

This tutorial on GANs will help you build a neural network that fills in the missing part of a handwritten digit. It covers how to build an MNIST digit classifier and how to simulate a dataset of handwritten digits with sections of the numbers missing. Next, you will learn how to use the MNIST classifier to predict on the masked (simulated) MNIST digits and how to implement a GAN that generates the missing regions of the digits. Finally, the tutorial uses the MNIST classifier to predict on the digits generated by the GAN and compares its performance on the masked data and on the generated data.

This tutorial is an excerpt from the book Python Deep Learning Projects by Matthew Lamons, Rahul Kumar, and Abhishek Nagaraja.
The book helps you develop your own deep learning systems in a straightforward and efficient way. It uses complex deep learning projects in computational linguistics and computer vision to help you master the subject.

All of the Python files and Jupyter Notebook files for this tutorial can be found at GitHub.

In this tutorial, we will be using the Keras deep learning library.

Importing all of the dependencies

We will be using numpy, matplotlib, keras, tensorflow, and the tqdm package in this exercise. Here, TensorFlow is used as the backend for Keras. You can install these packages with pip. For the MNIST data, we will be using the dataset available in the keras module with a simple import:

import numpy as np
import random
import matplotlib.pyplot as plt
%matplotlib inline
from tqdm import tqdm

from keras.layers import Input, Conv2D
from keras.layers import AveragePooling2D, BatchNormalization
from keras.layers import UpSampling2D, Flatten, Activation
from keras.models import Model, Sequential
from keras.layers import Dense, Dropout
from keras.layers import LeakyReLU
from keras.optimizers import Adam
from keras import backend as k

from keras.datasets import mnist

It is important to set a seed for reproducibility:

# set seed for reproducibility
seed_val = 9000
np.random.seed(seed_val)
random.seed(seed_val)

Exploring the data

We will load the MNIST data into our session from the keras module with mnist.load_data(). After doing so, we will print the shape and the size of the dataset, as well as the number of classes and unique labels in the dataset:

(X_train, y_train), (X_test, y_test) = mnist.load_data()
print('Size of the training_set: ', X_train.shape)
print('Size of the test_set: ', X_test.shape)
print('Shape of each image: ', X_train[0].shape)
print('Total number of classes: ', len(np.unique(y_train)))
print('Unique class labels: ', np.unique(y_train))

We have a training set of 60,000 images (and a test set of 10,000 images) spread across 10 classes, with roughly 6,000 training images per class. Each image has a shape of 28*28.
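The exact number of images per class can be checked with np.bincount. The sketch below wraps it in a small helper, class_counts(), which is a hypothetical name used only for illustration; it runs on toy labels here, and calling class_counts(y_train) on the loaded MNIST labels shows how close each class is to 6,000.

```python
import numpy as np

def class_counts(labels, num_classes=10):
    """Return the number of examples for each integer class label 0..num_classes-1."""
    # minlength guarantees one entry per class, even if a class is absent
    return np.bincount(np.asarray(labels), minlength=num_classes)

# Toy labels for illustration; with MNIST you would pass y_train instead
toy_labels = np.array([0, 1, 1, 3, 9])
print(class_counts(toy_labels))  # [1 2 0 1 0 0 0 0 0 1]
```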

Let’s plot and see what the handwritten images look like:

# Plot of the first 9 images
for i in range(0, 9):
    plt.subplot(331+i) # plot of 3 rows and 3 columns
    plt.axis('off') # turn off axis
    plt.imshow(X_train[i], cmap='gray') # gray scale

The output is as follows:

Let’s plot a handwritten digit from each class:

# plotting image from each class
fig=plt.figure(figsize=(8, 4))
columns = 5
rows = 2
for i in range(0, rows*columns):
    fig.add_subplot(rows, columns, i+1)
    plt.title(str(i)) # label 
    plt.axis('off') # turn off axis
    plt.imshow(X_train[np.where(y_train==i)][0], cmap='gray') # gray scale

The output is as follows:

Let’s look at the maximum and minimum pixel values in the dataset:

print('Maximum pixel value in the training_set: ', np.max(X_train))
print('Minimum pixel value in the training_set: ', np.min(X_train))

The output is as follows:

Preparing the data

Type conversion, centering, scaling, and reshaping are some of the pre-processing we will implement in this tutorial.

Type conversion, centering and scaling

Set the type to np.float32.

For centering, we subtract 127.5 from every pixel value. The values in the dataset will now range from -127.5 to 127.5.

For scaling, we divide the centered dataset by half of the maximum pixel value in the dataset, that is, 255/2 = 127.5. This results in a dataset with values ranging from -1 to 1:

# Converting integer values to float types 
X_train = X_train.astype(np.float32)
X_test = X_test.astype(np.float32)
# Scaling and centering
X_train = (X_train - 127.5) / 127.5
X_test = (X_test - 127.5)/ 127.5
print('Maximum pixel value in the training_set after Centering and Scaling: ', np.max(X_train))
print('Minimum pixel value in the training_set after Centering and Scaling: ', np.min(X_train))

Let’s define a function to rescale the pixel values of the scaled image back to the range 0 to 255:

# Rescale the pixel values (0 to 255)
def upscale(image):
    # round and clip before casting: astype(np.uint8) truncates and wraps out-of-range values
    return np.clip(np.rint(image*127.5 + 127.5), 0, 255).astype(np.uint8)
# Let's see if this works
z = upscale(X_train[0])
print('Maximum pixel value after upscaling scaled image: ', np.max(z))
print('Minimum pixel value after upscaling scaled image: ', np.min(z))
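The scaling and rescaling steps can be sanity-checked with a round trip over all 256 possible pixel values. The helpers scale() and upscale_rounded() below are hypothetical names for this sketch; upscale_rounded() rounds to the nearest integer and clips before casting, because a plain astype(np.uint8) truncates, so a tiny floating-point error (for example 0.99999994) could otherwise turn a pixel value of 1 into 0.

```python
import numpy as np

def scale(images):
    """Map uint8 pixels in [0, 255] to float32 values in [-1, 1]."""
    return (images.astype(np.float32) - 127.5) / 127.5

def upscale_rounded(image):
    """Inverse of scale(): round to nearest, clip to [0, 255], then cast to uint8."""
    return np.clip(np.rint(image * 127.5 + 127.5), 0, 255).astype(np.uint8)

pixels = np.arange(256, dtype=np.uint8)   # every possible 8-bit pixel value
scaled = scale(pixels)
print(scaled.min(), scaled.max())                        # -1.0 1.0
print(np.array_equal(upscale_rounded(scaled), pixels))   # True
```

Clipping also keeps values slightly outside [-1, 1], such as generator outputs, from wrapping around when cast to uint8.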

A plot of 9 centered and scaled images after upscaling:

for i in range(0, 9):
    plt.subplot(331+i) # plot of 3 rows and 3 columns
    plt.axis('off') # turn off axis
    plt.imshow(upscale(X_train[i]), cmap='gray') # gray scale

The output is as follows:

Masking/inserting noise

For the needs of this project, we need to simulate a dataset of incomplete digits. So, let’s write a function to mask small regions in the original image to form the noised dataset.

The idea is to mask an 8*8 region of the image, with the top-left corner of the mask falling between the 9th and 12th pixel (index 8 to 11) along both the x and y axes of the image. This makes sure that we always mask around the center of the image:

def noising(image):
    array = np.array(image) # copy, so the original image is left untouched
    i = random.choice(range(8,12)) # row index of the top-left corner of the mask
    j = random.choice(range(8,12)) # column index of the top-left corner of the mask
    array[i:i+8, j:j+8]=-1.0 # setting the pixels in the masked region to -1
    return array
noised_train_data = np.array([*map(noising, X_train)])
noised_test_data = np.array([*map(noising, X_test)])
print('Noised train data Shape/Dimension : ', noised_train_data.shape)
print('Noised test data Shape/Dimension : ', noised_test_data.shape)
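noising() runs once per image in Python. The same masking can be expressed with NumPy fancy indexing over the whole batch at once. The sketch below is an alternative, not the book’s code: noising_vectorized() is a hypothetical helper, it expects a batch of shape (N, 28, 28) (that is, before the reshape step), and it draws the corners with NumPy’s Generator instead of the random module, so its masks will not match noising()’s masks even with the same seed.

```python
import numpy as np

def noising_vectorized(images, mask_size=8, low=8, high=12, rng=None):
    """Set one mask_size x mask_size square per image to -1.0.

    The top-left corner (row, col) is drawn uniformly from [low, high),
    like random.choice(range(8, 12)) in noising().
    """
    rng = np.random.default_rng() if rng is None else rng
    out = np.array(images, copy=True)            # leave the input untouched
    n = out.shape[0]
    rows = rng.integers(low, high, size=n)       # top-left row per image
    cols = rng.integers(low, high, size=n)       # top-left column per image
    offsets = np.arange(mask_size)
    # Broadcast (n,1,1), (n,m,1) and (n,1,m) index arrays to (n,m,m)
    n_idx = np.arange(n)[:, None, None]
    r_idx = (rows[:, None] + offsets)[:, :, None]
    c_idx = (cols[:, None] + offsets)[:, None, :]
    out[n_idx, r_idx, c_idx] = -1.0
    return out
```

For example, noised_train_data = noising_vectorized(X_train, rng=np.random.default_rng(9000)).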

A plot of 9 scaled noised images after upscaling:

# Plot of 9 scaled noised images after upscaling
for i in range(0, 9):
    plt.subplot(331+i) # plot of 3 rows and 3 columns
    plt.axis('off') # turn off axis
    plt.imshow(upscale(noised_train_data[i]), cmap='gray') # gray scale

The output is as follows:


Reshape the original and noised training sets to a shape of 60000*28*28*1 (and the test sets to 10000*28*28*1). This is important since the 2D convolutions expect to receive images of a shape of 28*28*1:

# Reshaping the training data
X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], X_train.shape[2], 1)
print('Size/Shape of the original training set: ', X_train.shape)
# Reshaping the noised training data
noised_train_data = noised_train_data.reshape(noised_train_data.shape[0],
noised_train_data.shape[1], noised_train_data.shape[2], 1)
print('Size/Shape of the noised training set: ', noised_train_data.shape)

# Reshaping the testing data
X_test = X_test.reshape(X_test.shape[0], X_test.shape[1], X_test.shape[2], 1)
print('Size/Shape of the original test set: ', X_test.shape)

# Reshaping the noised testing data
noised_test_data = noised_test_data.reshape(noised_test_data.shape[0],
noised_test_data.shape[1], noised_test_data.shape[2], 1)
print('Size/Shape of the noised test set: ', noised_test_data.shape)
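The four-argument reshape calls can also be written with NumPy’s np.newaxis indexing, which appends the channel axis without spelling out each dimension. A minimal sketch; add_channel_axis() is a hypothetical helper name:

```python
import numpy as np

def add_channel_axis(images):
    """Turn a (N, H, W) batch into (N, H, W, 1), as Conv2D expects for grayscale input."""
    # images[..., np.newaxis] is a view, equivalent to images.reshape(N, H, W, 1)
    return images[..., np.newaxis]

batch = np.zeros((5, 28, 28), dtype=np.float32)
print(add_channel_axis(batch).shape)  # (5, 28, 28, 1)
```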

MNIST classifier

To start off with modeling, let’s build a simple convolutional neural network (CNN) digit classifier.

The first layer is a convolution layer that has 32 filters of a shape of 3*3, with relu activation and Dropout as the regularizer. The second layer is a convolution layer that has 64 filters of a shape of 3*3, with relu activation and Dropout as the regularizer. Both of these layers use a stride of 2. The third layer is a convolution layer that has 128 filters of a shape of 3*3, with relu activation and Dropout as the regularizer; its output is then flattened. The fourth layer is a Dense layer of 1024 neurons with relu activation. The final layer is a Dense layer with 10 neurons corresponding to the 10 classes in the MNIST dataset, and the activation used here is softmax. The batch_size is set to 128, the optimizer used is adam, and validation_split is set to 0.2. This means that 20% of the training set will be used as the validation set:

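# input image shape
input_shape = (28,28,1)
def train_mnist(input_shape, X_train, y_train):
    model = Sequential()
    # NOTE: the Dropout rate of 0.2 is an assumption; the text does not state it
    model.add(Conv2D(32, (3, 3), strides=2, padding='same',
                     input_shape=input_shape))
    model.add(Activation('relu'))
    model.add(Dropout(0.2))

    model.add(Conv2D(64, (3, 3), strides=2, padding='same'))
    model.add(Activation('relu'))
    model.add(Dropout(0.2))

    model.add(Conv2D(128, (3, 3), padding='same'))
    model.add(Activation('relu'))
    model.add(Dropout(0.2))
    model.add(Flatten())

    model.add(Dense(1024, activation = 'relu'))
    model.add(Dense(10, activation='softmax'))
    model.compile(loss = 'sparse_categorical_crossentropy',
                  optimizer = 'adam', metrics = ['accuracy'])
    model.fit(X_train, y_train, batch_size = 128,
              epochs = 3, validation_split=0.2, verbose = 1)
    return model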

mnist_model = train_mnist(input_shape, X_train, y_train)

The output is as follows:
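The classifier’s layer sizes can be checked by hand. The sketch below assumes the architecture described above: three 3*3 convolutions with 32, 64, and 128 filters, strides of 2, 2, and 1, 'same' padding, and then Dense(1024) and Dense(10). With 'same' padding, a Keras Conv2D produces ceil(input_size / stride) outputs per spatial axis, and Activation, Dropout, and Flatten layers add no parameters. The helper names are hypothetical.

```python
import math

def same_conv_out(size, stride):
    """Spatial size after a Conv2D with padding='same': ceil(size / stride)."""
    return math.ceil(size / stride)

def conv_params(kernel, c_in, c_out):
    """kernel*kernel*c_in weights per filter, plus one bias per filter."""
    return kernel * kernel * c_in * c_out + c_out

def dense_params(n_in, n_out):
    """One weight per input-output pair, plus one bias per output."""
    return n_in * n_out + n_out

size = 28
size = same_conv_out(size, 2)   # Conv2D(32, strides=2): 28 -> 14
size = same_conv_out(size, 2)   # Conv2D(64, strides=2): 14 -> 7
size = same_conv_out(size, 1)   # Conv2D(128, strides=1): 7 -> 7
flat = size * size * 128        # Flatten: 7 * 7 * 128 = 6272

total = (conv_params(3, 1, 32) + conv_params(3, 32, 64) + conv_params(3, 64, 128)
         + dense_params(flat, 1024) + dense_params(1024, 10))
print(size, flat, total)  # 7 6272 6526474
```

Under these assumptions the Dense(1024) layer holds about 6.4 million of the roughly 6.5 million parameters, and mnist_model.summary() should report the same total if the model matches them.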

Use the built CNN digit classifier on the masked images to get a measure of its performance on digits that are missing small sections:

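# prediction on the masked images
# (predict_classes() was removed in newer Keras versions; argmax over the softmax outputs gives the same labels)
pred_labels = np.argmax(mnist_model.predict(noised_test_data), axis=1)
print('The model accuracy on the masked images is:', np.mean(pred_labels==y_test)*100)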

On the masked images, the CNN digit classifier is 74.9% accurate. Your result may differ slightly from run to run, but it should be very close.

Defining hyperparameters for GAN

The following hyperparameters are used throughout the code; all of them are configurable:

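# Label value for real images (one-sided label smoothing)
smooth_real = 0.9
# Number of epochs
epochs = 5

# Batch size
batch_size = 128

# Optimizer for the generator
# (learning_rate replaces the older lr argument, which is deprecated in newer Keras versions)
optimizer_g = Adam(learning_rate=0.0002, beta_1=0.5)

# Optimizer for the discriminator
optimizer_d = Adam(learning_rate=0.0004, beta_1=0.5)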

# Shape of the input image
input_shape = (28,28,1)

Building the GAN model components

The final GAN model should be able to fill in the part of the image that is missing (masked). The generator, the discriminator, and the DCGAN that combines them are defined in our book; their definitions (the img_generator(), img_discriminator(), and dcgan() functions used in the training loop below) are not reproduced in this excerpt.

Training GAN

We’ve built the components of the GAN. Let’s train the model in the next steps!

Plotting the training – part 1

During each epoch, the following function plots 9 generated images. For comparison, it also plots the corresponding 9 original target images and 9 noised input images. We use the upscale function defined earlier when plotting, so that the images are scaled to the range 0 to 255 and you do not encounter issues when plotting:

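def generated_images_plot(original, noised_data, generator):
    # 9 noised input images (a new figure per group, so the groups don't overwrite each other)
    plt.figure(figsize=(4, 4))
    plt.suptitle('Noised input')
    for i in range(9):
        plt.subplot(331 + i)
        plt.imshow(upscale(np.squeeze(noised_data[i])), cmap='gray') # upscale for plotting
    plt.show()

    # 9 generated images
    plt.figure(figsize=(4, 4))
    plt.suptitle('Generated')
    for i in range(9):
        pred = generator.predict(noised_data[i:i+1], verbose=0)
        plt.subplot(331 + i)
        plt.imshow(upscale(np.squeeze(pred[0])), cmap='gray') # upscale to avoid plotting errors
    plt.show()

    # 9 original target images
    plt.figure(figsize=(4, 4))
    plt.suptitle('Original')
    for i in range(9):
        plt.subplot(331 + i)
        plt.imshow(upscale(np.squeeze(original[i])), cmap='gray') # upscale for plotting
    plt.show()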

The output of this function is as follows:

Plotting the training – part 2

Let’s define another function that plots the images generated during each epoch. To show the difference, we also include the original and the masked/noised images in the plot.

The plot has 12 rows arranged in groups of three. Within each group, the top row contains the original images, the middle row contains the masked images, and the bottom row contains the generated images.

So the sequence is row 1 – original, row 2 – masked, row 3 – generated, row 4 – original, row 5 – masked, …, row 12 – generated.

Let’s take a look at the code:

def plot_generated_images_combined(original, noised_data, generator):
    rows, cols = 4, 12
    num = rows * cols
    image_size = 28
    generated_images = generator.predict(noised_data[0:num])

    # arrange the images into 12 rows: each group of three rows shows
    # the original, masked, and generated versions of the same digits
    imgs = np.concatenate([original[0:num], noised_data[0:num], generated_images])
    imgs = imgs.reshape((rows * 3, cols, image_size, image_size))
    imgs = np.vstack(np.split(imgs, rows, axis=1))
    imgs = imgs.reshape((rows * 3, -1, image_size, image_size))
    imgs = np.vstack([np.hstack(i) for i in imgs])
    imgs = upscale(imgs)
    plt.figure(figsize=(8, 8))  # new figure, so we don't draw into the previous plot
    plt.title('Original Images: top rows, '
              'Corrupted Input: middle rows, '
              'Generated Images: bottom rows')
    plt.imshow(imgs, cmap='gray')
    plt.show()

The output is as follows:
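The reshape and split calls in plot_generated_images_combined() are compact but hard to follow. The sketch below builds a grid with the same three-row grouping (original, masked, generated) using explicit loops. tile_triplets() is a hypothetical helper; it assumes 28*28 images with a trailing channel axis and at least rows*cols images per batch, and the left-to-right order of the digits may differ from the original function’s.

```python
import numpy as np

def tile_triplets(original, masked, generated, rows=4, cols=12, size=28):
    """Return one 2D array in which every group of three tile rows shows the
    original, masked, and generated versions of the same `cols` digits."""
    strips = []
    for g in range(rows):
        block = slice(g * cols, (g + 1) * cols)           # digits shown in this group
        for batch in (original, masked, generated):
            tiles = batch[block].reshape(cols, size, size)  # drop the channel axis
            strips.append(np.hstack(list(tiles)))           # shape (size, cols * size)
    return np.vstack(strips)                                # (rows * 3 * size, cols * size)
```

For example: plt.imshow(upscale(tile_triplets(original, noised_data, generator.predict(noised_data[:48]))), cmap='gray').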

Training loop

Now we are at the most important part of the code: the part where all of the functions we previously defined are used. The following are the steps:

  1. Load the generator by calling the img_generator() function.
  2. Load the discriminator by calling the img_discriminator() function and compile it with the binary cross-entropy loss and optimizer as optimizer_d, which we have defined under the hyperparameters section.
  3. Feed the generator and the discriminator to the dcgan() function and compile it with the binary cross-entropy loss and optimizer as optimizer_g, which we have defined under the hyperparameters section.
  4. Create a new batch of original images and masked images. Generate new fake images by feeding the batch of masked images to the generator.
  5. Concatenate the original and generated images so that the first 128 images are all original and the next 128 images are all fake. It is important that you do not shuffle the data here, otherwise it will be hard to train. Label the generated images as 0 and original images as 0.9 instead of 1. This is one-sided label smoothing on the original images. The reason for using label smoothing is to make the network resilient to adversarial examples. It’s called one-sided because we are smoothing labels only for the real images.
  6. Set discriminator.trainable to True to enable training of the discriminator and feed this set of 256 images and their corresponding labels to the discriminator for classification.
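  7. Now, set discriminator.trainable to False and feed a new batch of 128 masked images labeled as 1 to the GAN (DCGAN) for classification. It is important to set discriminator.trainable to False to make sure the discriminator is not getting trained while training the generator. Because Keras applies the trainable flag when a model is compiled, the discriminator should also be frozen at the time the DCGAN is compiled.
  8. Repeat steps 4 through 7 for every batch in an epoch, for the desired number of epochs.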
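Steps 4 to 7 boil down to a few array operations. The NumPy sketch below shows only the batch and label construction, with random arrays standing in for real and generated images; the helper names discriminator_batch() and generator_labels() are hypothetical, and no Keras model is involved.

```python
import numpy as np

def discriminator_batch(original, generated, smooth_real=0.9):
    """Stack real images first, then fake ones (no shuffling), with
    one-sided smoothed labels: real -> smooth_real, fake -> 0."""
    x = np.concatenate([original, generated])
    y = np.zeros(len(x), dtype=np.float32)
    y[:len(original)] = smooth_real
    return x, y

def generator_labels(batch_size):
    """Labels for the DCGAN step: the generator is rewarded when the
    frozen discriminator calls its outputs real (label 1)."""
    return np.ones(batch_size, dtype=np.float32)

batch_size = 128
iterations = 60000 // batch_size   # 468 full batches per epoch on MNIST
rng = np.random.default_rng(0)
original = rng.uniform(-1, 1, size=(batch_size, 28, 28, 1))
generated = rng.uniform(-1, 1, size=(batch_size, 28, 28, 1))  # stand-in for generator.predict(noise)

x, y = discriminator_batch(original, generated)
print(iterations, x.shape, y[:2], y[-2:])
```

Only the real labels are softened; the fake labels stay at exactly 0, which is what makes the smoothing one-sided.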

We call the plot_generated_images_combined() and generated_images_plot() functions after the first iteration of the first epoch and at the end of each epoch.

Feel free to move these plot calls to change how often the plots are displayed:

def train(X_train, noised_train_data,
          input_shape, smooth_real,
          epochs, batch_size,
          optimizer_g, optimizer_d):
    # define two empty lists to store the discriminator
    # and the generator losses
    discriminator_losses = []
    generator_losses = []

    # Number of iterations possible with batches of size 128
    iterations = X_train.shape[0] // batch_size

    # Load the generator and the discriminator
    generator = img_generator(input_shape)
    discriminator = img_discriminator(input_shape)

    # Compile the discriminator with binary_crossentropy loss
    discriminator.compile(loss='binary_crossentropy', optimizer=optimizer_d)

    # Feed the generator and the discriminator to the function dcgan
    # to form the DCGAN architecture
    gan = dcgan(discriminator, generator, input_shape)

    # Freeze the discriminator before compiling the DCGAN,
    # since Keras applies the trainable flag at compile time
    discriminator.trainable = False
    # Compile the DCGAN with binary_crossentropy loss
    gan.compile(loss='binary_crossentropy', optimizer=optimizer_g)

    for i in range(epochs):
        print('Epoch %d' % (i+1))
        # Use tqdm to get an estimate of time remaining
        for j in tqdm(range(1, iterations+1)):

            # batch of original images (batch = batchsize)
            original = X_train[np.random.randint(0, X_train.shape[0], size=batch_size)]

            # batch of noised images (batch = batchsize)
            noise = noised_train_data[np.random.randint(0, noised_train_data.shape[0], size=batch_size)]

            # Generate fake images
            generated_images = generator.predict(noise, verbose=0)

            # Labels for generated data
            dis_lab = np.zeros(2*batch_size)

            # data for discriminator
            dis_train = np.concatenate([original, generated_images])

            # label smoothing for original images
            dis_lab[:batch_size] = smooth_real

            # Train discriminator on original and generated images
            discriminator.trainable = True
            discriminator_loss = discriminator.train_on_batch(dis_train, dis_lab)

            # save the losses
            discriminator_losses.append(discriminator_loss)

            # Train generator
            gen_lab = np.ones(batch_size)
            discriminator.trainable = False
            sample_indices = np.random.randint(0, X_train.shape[0], size=batch_size)
            original = X_train[sample_indices]
            noise = noised_train_data[sample_indices]

            generator_loss = gan.train_on_batch(noise, gen_lab)

            # save the losses
            generator_losses.append(generator_loss)

            if i == 0 and j == 1:
                print('Iteration - %d' % j)
                generated_images_plot(original, noise, generator)
                plot_generated_images_combined(original, noise, generator)

        print("Discriminator Loss: ", discriminator_loss,
              ", Adversarial Loss: ", generator_loss)

        # training plot 1
        generated_images_plot(original, noise, generator)
        # training plot 2
        plot_generated_images_combined(original, noise, generator)

    # plot the training losses
    plt.figure()
    plt.plot(range(len(discriminator_losses)), discriminator_losses,
             color='red', label='Discriminator loss')
    plt.plot(range(len(generator_losses)), generator_losses,
             color='blue', label='Adversarial loss')
    plt.title('Discriminator and Adversarial loss')
    plt.xlabel('Iterations')
    plt.ylabel('Loss (Adversarial/Discriminator)')
    plt.legend()
    plt.show()

    return generator

generator = train(X_train, noised_train_data,
                  input_shape, smooth_real,
                  epochs, batch_size,
                  optimizer_g, optimizer_d)

The output is as follows:

 Generated images plotted with training plots at the end of the first iteration of epoch 1
 Generated images plotted with training plots at the end of epoch 2
 Generated images plotted with training plots at the end of epoch 5
 Plot of the discriminator and adversarial loss during training
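The recorded losses fluctuate from batch to batch, so trends can be hard to read from the raw curves. One option, not used in the original code, is to plot a trailing moving average; moving_average() below is a hypothetical helper built on np.convolve.

```python
import numpy as np

def moving_average(values, window=50):
    """Trailing mean over `window` consecutive values.

    Returns len(values) - window + 1 points; if there are fewer than
    `window` values, the input is returned unchanged as a float array.
    """
    values = np.asarray(values, dtype=np.float64)
    if len(values) < window:
        return values.copy()
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode='valid')
```

Inside train(), for example, plt.plot(moving_average(discriminator_losses), color='red', label='Discriminator loss (smoothed)') would draw the smoothed curve.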


CNN classifier predictions on the noised and generated images

We will call the generator on the masked MNIST test data to generate images, that is, fill in the missing part of the digits:

# restore missing parts of the digit with the generator
gen_imgs_test = generator.predict(noised_test_data)

Then, we will pass the generated MNIST digits to the digit classifier we have modeled already:

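# predict on the restored/generated digits
# (argmax over the softmax outputs replaces predict_classes(), which newer Keras versions removed)
gen_pred_lab = np.argmax(mnist_model.predict(gen_imgs_test), axis=1)
print('The model accuracy on the generated images is:', np.mean(gen_pred_lab==y_test)*100)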

The MNIST CNN classifier is 87.82% accurate on the generated data.
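To see which digits benefit most from inpainting, the two overall accuracies can be broken down by true class. per_class_accuracy() is a hypothetical helper; comparing per_class_accuracy(pred_labels, y_test) with per_class_accuracy(gen_pred_lab, y_test) shows the per-digit change.

```python
import numpy as np

def per_class_accuracy(pred, labels, num_classes=10):
    """Accuracy (%) for each true class; NaN for classes with no examples."""
    pred = np.asarray(pred)
    labels = np.asarray(labels)
    acc = np.full(num_classes, np.nan)
    for c in range(num_classes):
        mask = labels == c          # examples whose true label is c
        if mask.any():
            acc[c] = np.mean(pred[mask] == c) * 100
    return acc
```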

The following plot shows 10 images generated by the generator, the actual label of each image, and the label predicted by the digit classifier for the generated image:

# plot of 10 generated images and their predicted label
fig=plt.figure(figsize=(8, 4))
fig.suptitle('Generated Images')
columns = 5
rows = 2
for i in range(0, rows*columns):
    fig.add_subplot(rows, columns, i+1)
    plt.title('Act: %d, Pred: %d'%(y_test[i], gen_pred_lab[i])) # actual vs predicted label
    plt.axis('off') # turn off axis
    plt.imshow(upscale(np.squeeze(gen_imgs_test[i])), cmap='gray') # gray scale

The output is as follows:

The Jupyter Notebook code files for the preceding DCGAN MNIST inpainting can be found at GitHub, along with the Jupyter Notebook code files for DCGAN inpainting on Fashion MNIST.


We built a deep convolutional GAN (DCGAN) in Keras on handwritten MNIST digits and examined the roles of its generator and discriminator components. We defined key hyperparameters and, in some places, explained why we chose them. Finally, we tested the GAN on unseen test data: the digit classifier was 74.9% accurate on the masked digits and 87.82% accurate on the digits restored by the generator.

To work through more insightful projects for mastering deep learning and neural network architectures with Python and Keras, check out the book Python Deep Learning Projects.
