[box type="note" align="" class="" width=""]The following excerpt is taken from Chapter 2 – Learning Features with Unsupervised Generative Networks of the book Deep Learning with Theano, written by Christopher Bourez. This book talks about modeling and training effective deep learning models with Theano, a popular Python-based deep learning library. [/box]
In this article, we introduce you to the concept of Generative Adversarial Networks, a popular class of Artificial Intelligence algorithms used in unsupervised machine learning. Code files for this particular chapter are available for download towards the end of the post.
Generative adversarial networks are composed of two models that are alternately trained to compete with each other. The generator network G is optimized to reproduce the true data distribution by generating data that is difficult for the discriminator D to distinguish from real data. Meanwhile, the second network, D, is optimized to distinguish real data from synthetic data generated by G. Overall, the training procedure is similar to a two-player minimax game with the following objective function:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]$$
Here, x is real data sampled from the real data distribution, and z is the noise vector of the generative model. In some ways, the discriminator and the generator can be seen as the police and the thief: to make sure training works correctly, the police are trained twice as much as the thief. Let’s illustrate GANs with images as data. In particular, let’s again take our MNIST digits example from Chapter 2, Classifying Handwritten Digits with a Feedforward Network, and train a generative adversarial network to generate images conditioned on the digit we want.
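To make the objective concrete, the value function V(D, G) = E_x[log D(x)] + E_z[log(1 − D(G(z)))] can be estimated on a batch from the discriminator's outputs alone. The sketch below is a minimal NumPy illustration under that assumption; `gan_value` is a hypothetical helper, not part of the book's code. The discriminator tries to push V up toward 0, while the generator tries to push it down; when D outputs 1/2 everywhere (it cannot tell real from fake), V equals −log 4, which is the value at the theoretical equilibrium of the game.

```python
import numpy as np

def gan_value(d_real, d_fake):
    """Batch estimate of V(D, G) = E_x[log D(x)] + E_z[log(1 - D(G(z)))].

    d_real: discriminator outputs D(x) on real samples (values in (0, 1))
    d_fake: discriminator outputs D(G(z)) on generated samples (values in (0, 1))
    """
    d_real = np.asarray(d_real, dtype=float)
    d_fake = np.asarray(d_fake, dtype=float)
    return np.mean(np.log(d_real)) + np.mean(np.log(1.0 - d_fake))

# An undecided discriminator (0.5 everywhere) gives -log(4)
print(gan_value([0.5, 0.5], [0.5, 0.5]))
# A confident, mostly correct discriminator brings V close to its maximum of 0
print(gan_value([0.99, 0.98], [0.01, 0.02]))
```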
The GAN method consists of training the generative model using a second model, the discriminative network, which classifies input data as real or fake. In this case, we can simply reuse our MNIST image classification model as the discriminator, with two classes, real or fake, for the prediction output (a single sigmoid unit giving the probability that the input is real), and also condition it on the label of the digit that is supposed to be generated. To condition the net on the label, the digit label is concatenated with the inputs:
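```python
import theano.tensor as T
from theano.sandbox.cuda.dnn import dnn_conv  # cuDNN convolution (legacy CUDA backend)

# `batchnorm` is a batch-normalization helper assumed to be defined elsewhere
# (for example in the chapter's code files); it is not part of the Theano API.

def conv_cond_concat(x, y):
    # Tile the label maps y (batch, n_labels, 1, 1) over the spatial size of x
    # and append them as extra channels: (batch, channels + n_labels, height, width)
    return T.concatenate([x, y * T.ones((x.shape[0], y.shape[1], x.shape[2], x.shape[3]))], axis=1)

def discrim(X, Y, w, w2, w3, wy):
    # One-hot labels (batch, n_labels) -> (batch, n_labels, 1, 1)
    yb = Y.dimshuffle(0, 1, 'x', 'x')
    X = conv_cond_concat(X, yb)
    # Strided convolution followed by a leaky ReLU (leak 0.2)
    h = T.nnet.relu(dnn_conv(X, w, subsample=(2, 2), border_mode=(2, 2)), alpha=0.2)
    h = conv_cond_concat(h, yb)
    h2 = T.nnet.relu(batchnorm(dnn_conv(h, w2, subsample=(2, 2), border_mode=(2, 2))), alpha=0.2)
    # Flatten the feature maps and append the labels again
    h2 = T.flatten(h2, 2)
    h2 = T.concatenate([h2, Y], axis=1)
    h3 = T.nnet.relu(batchnorm(T.dot(h2, w3)))
    h3 = T.concatenate([h3, Y], axis=1)
    # Probability that the input is a real image
    y = T.nnet.sigmoid(T.dot(h3, wy))
    return y
```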
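The conditioning trick in `conv_cond_concat` is easier to see on concrete arrays. The following NumPy sketch mirrors it under the assumption of one-hot labels for the 10 digits; `conv_cond_concat_np` is a hypothetical name used only for illustration. Each label value is broadcast over the whole height × width grid and stacked as an extra channel, so every spatial position of the convolution sees the digit label.

```python
import numpy as np

def conv_cond_concat_np(x, y):
    """NumPy analogue of conv_cond_concat.

    x: feature maps, shape (batch, channels, height, width)
    y: labels reshaped to (batch, n_labels, 1, 1)
    Returns an array of shape (batch, channels + n_labels, height, width).
    """
    # Broadcast each label value over the full height x width grid
    tiled = y * np.ones((x.shape[0], y.shape[1], x.shape[2], x.shape[3]))
    return np.concatenate([x, tiled], axis=1)

batch, n_labels = 2, 10
x = np.random.rand(batch, 1, 28, 28)    # a batch of MNIST-sized images
labels = np.eye(n_labels)[[3, 7]]        # one-hot labels for digits 3 and 7
yb = labels[:, :, None, None]            # same role as Y.dimshuffle(0, 1, 'x', 'x')
out = conv_cond_concat_np(x, yb)
print(out.shape)  # (2, 11, 28, 28)
```

Channel 0 holds the image; channels 1 to 10 are constant planes, with a plane of ones at the position of the target digit.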
Note the use of two leaky rectified linear units, with a leak of 0.2, as activation for the first two convolutions.
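A leaky ReLU passes positive inputs unchanged and multiplies negative inputs by the leak, so some gradient still flows for negative activations. Theano's `T.nnet.relu(x, alpha=0.2)` computes this function; the NumPy version below is a minimal equivalent for illustration (`leaky_relu` is a hypothetical helper name).

```python
import numpy as np

def leaky_relu(x, alpha=0.2):
    # Identity for positive inputs, slope alpha for negative inputs
    x = np.asarray(x, dtype=float)
    return np.where(x > 0, x, alpha * x)

print(leaky_relu([-2.0, -0.5, 0.0, 1.5]))  # values: -0.4, -0.1, 0.0, 1.5
```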
To generate an image given noise and a label, the generator network consists of a stack of deconvolutions (preceded by two fully connected layers), using an input noise vector z that consists of 100 real numbers ranging from 0 to 1.
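The generator therefore takes two inputs per sample: a noise vector and the label of the digit to draw. A minimal sketch of building such a batch, assuming uniform noise in [0, 1) and one-hot labels for the 10 digits, could look as follows; `sample_generator_inputs` is a hypothetical helper, and the `float32` cast is an assumption chosen to match typical Theano GPU settings.

```python
import numpy as np

def sample_generator_inputs(digits, nz=100, n_labels=10, rng=None):
    """Build a (Z, Y) batch for the generator.

    digits: the digit (0-9) to generate for each sample
    Returns Z with shape (batch, nz), uniform in [0, 1), and
    one-hot labels Y with shape (batch, n_labels).
    """
    rng = np.random.default_rng() if rng is None else rng
    digits = np.asarray(digits)
    Z = rng.uniform(0.0, 1.0, size=(len(digits), nz)).astype('float32')
    Y = np.eye(n_labels, dtype='float32')[digits]
    return Z, Y

Z, Y = sample_generator_inputs([0, 1, 2, 3], rng=np.random.default_rng(0))
print(Z.shape, Y.shape)  # (4, 100) (4, 10)
```

Keeping the noise fixed and changing only `digits` is a simple way to ask the trained generator for different digits in a similar style.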
To create a deconvolution in Theano, a dummy convolutional forward pass is created, and its gradient with respect to its input is used as the deconvolution:
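```python
import theano.tensor as T
from theano.sandbox.cuda.basic_ops import gpu_contiguous, gpu_alloc_empty
from theano.sandbox.cuda.dnn import GpuDnnConvDesc, GpuDnnConvGradI

# `batchnorm`, `conv_cond_concat` (defined with the discriminator) and the
# hyperparameter `ngf` (number of generator filters) are assumed defined elsewhere.

def deconv(X, w, subsample=(1, 1), border_mode=(0, 0), conv_mode='conv'):
    img = gpu_contiguous(T.cast(X, 'float32'))
    kerns = gpu_contiguous(T.cast(w, 'float32'))
    # Shape of the upsampled output: spatial dimensions multiplied by the stride
    out_shape = (img.shape[0], kerns.shape[1],
                 img.shape[2] * subsample[0], img.shape[3] * subsample[1])
    # Descriptor of the dummy forward convolution mapping out_shape -> img
    desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample,
                          conv_mode=conv_mode)(gpu_alloc_empty(*out_shape).shape, kerns.shape)
    out = gpu_alloc_empty(*out_shape)
    # Gradient of that convolution w.r.t. its input, applied to img: the deconvolution
    d_img = GpuDnnConvGradI()(kerns, img, out, desc)
    return d_img

def gen(Z, Y, w, w2, w3, wx):
    yb = Y.dimshuffle(0, 1, 'x', 'x')
    # Two fully connected layers, each conditioned on the label
    Z = T.concatenate([Z, Y], axis=1)
    h = T.nnet.relu(batchnorm(T.dot(Z, w)))
    h = T.concatenate([h, Y], axis=1)
    h2 = T.nnet.relu(batchnorm(T.dot(h, w2)))
    # Reshape into ngf*2 feature maps of size 7x7
    h2 = h2.reshape((h2.shape[0], ngf * 2, 7, 7))
    h2 = conv_cond_concat(h2, yb)
    # Deconvolution with stride 2: 7x7 -> 14x14
    h3 = T.nnet.relu(batchnorm(deconv(h2, w3, subsample=(2, 2), border_mode=(2, 2))))
    h3 = conv_cond_concat(h3, yb)
    # Deconvolution with stride 2: 14x14 -> 28x28, pixel values in (0, 1)
    x = T.nnet.sigmoid(deconv(h3, wx, subsample=(2, 2), border_mode=(2, 2)))
    return x
```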
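Why does the gradient of a convolution act as a deconvolution? The gradient of a strided convolution with respect to its input scatters each output value back over the input window it came from, which maps a small signal to a larger one. The 1-D NumPy sketch below illustrates this under simplifying assumptions (single channel, no padding); `conv1d` and `conv1d_transpose` are hypothetical helpers, not the cuDNN kernels used above. It checks the defining adjoint property ⟨conv(x), g⟩ = ⟨x, conv_transpose(g)⟩.

```python
import numpy as np

def conv1d(x, w, stride=1):
    """Strided 1-D cross-correlation without padding (the 'forward pass')."""
    k = len(w)
    n_out = (len(x) - k) // stride + 1
    return np.array([np.dot(w, x[i * stride:i * stride + k]) for i in range(n_out)])

def conv1d_transpose(g, w, stride=1):
    """Gradient of conv1d with respect to its input x, for upstream gradient g.

    Each output value is scattered back, weighted by w, over the input window
    it was computed from: this is the transposed convolution ("deconvolution").
    """
    k = len(w)
    x_grad = np.zeros((len(g) - 1) * stride + k)
    for i, gi in enumerate(g):
        x_grad[i * stride:i * stride + k] += gi * w
    return x_grad

rng = np.random.default_rng(0)
x = rng.standard_normal(9)
w = rng.standard_normal(3)
g = rng.standard_normal(4)   # conv1d(x, w, stride=2) has 4 outputs

y = conv1d(x, w, stride=2)
up = conv1d_transpose(g, w, stride=2)
print(len(y), len(up))  # 4 9
# Adjoint check: <conv(x), g> equals <x, conv_transpose(g)>
print(np.isclose(np.dot(y, g), np.dot(x, up)))  # True
```

The 4-sample signal is upsampled to 9 samples, just as the generator's stride-2 deconvolutions go from 7×7 to 14×14 to 28×28. Because several input sizes give the same strided output size (inputs of length 9 and 10 both yield 4 outputs here), the Theano `deconv` passes the desired output shape explicitly through `out`.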
Real data is given by the tuple (X,Y), while generated data is built from noise and label (Z,Y):
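```python
X = T.tensor4()   # real images
Z = T.matrix()    # noise vectors
Y = T.matrix()    # one-hot digit labels
# gen_params and discrim_params: lists of the models' weights, initialized elsewhere
gX = gen(Z, Y, *gen_params)
p_real = discrim(X, Y, *discrim_params)
p_gen = discrim(gX, Y, *discrim_params)
```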
Generator and discriminator models compete during adversarial learning:
- The discriminator is trained to label real data as real (1) and label generated data as generated (0), hence minimizing the following cost function:
```python
d_cost = (T.nnet.binary_crossentropy(p_real, T.ones(p_real.shape)).mean()
          + T.nnet.binary_crossentropy(p_gen, T.zeros(p_gen.shape)).mean())
```
- The generator is trained to deceive the discriminator as much as possible. Its training signal comes from the discriminator's prediction on generated data (p_gen), which the generator pushes toward the "real" label (1):
```python
g_cost = T.nnet.binary_crossentropy(p_gen, T.ones(p_gen.shape)).mean()
```
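The two costs can be evaluated by hand on a few example discriminator outputs. The NumPy sketch below reproduces the same binary cross-entropy expressions (the helper names are hypothetical). Minimizing `d_cost` is the same as maximizing the batch estimate of E_x[log D(x)] + E_z[log(1 − D(G(z)))]. Note that `g_cost` uses targets of 1, so it minimizes −log D(G(z)) rather than log(1 − D(G(z))) from the minimax objective; this is the widely used "non-saturating" form, which gives the generator stronger gradients when the discriminator confidently rejects its samples.

```python
import numpy as np

def binary_crossentropy(p, target):
    # Elementwise -[t * log(p) + (1 - t) * log(1 - p)], as in T.nnet.binary_crossentropy
    return -(target * np.log(p) + (1 - target) * np.log(1 - p))

def gan_costs(p_real, p_gen):
    p_real = np.asarray(p_real, dtype=float)
    p_gen = np.asarray(p_gen, dtype=float)
    # Discriminator: real -> 1, generated -> 0
    d_cost = (binary_crossentropy(p_real, np.ones_like(p_real)).mean()
              + binary_crossentropy(p_gen, np.zeros_like(p_gen)).mean())
    # Generator: wants generated samples classified as real (1)
    g_cost = binary_crossentropy(p_gen, np.ones_like(p_gen)).mean()
    return d_cost, g_cost

# Early in training the discriminator spots fakes easily: low d_cost, high g_cost
d, g = gan_costs([0.9, 0.95], [0.05, 0.1])
print(f"d_cost={d:.3f}, g_cost={g:.3f}")  # d_cost=0.157, g_cost=2.649
```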
The rest follows the usual pattern: the gradients of each cost with respect to its model's parameters are computed, and training alternately optimizes the weights of each model, updating the discriminator twice as often as the generator. Unlike in standard training, the competition between discriminator and generator does not lead to a steady decrease in each loss.
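The alternating schedule can be sketched independently of Theano. In the minimal sketch below, `train_d` and `train_g` are hypothetical callables standing in for compiled Theano training functions that each perform one update on a batch and return the loss; `k=2` gives two discriminator updates per generator update.

```python
def train_alternating(train_d, train_g, batches, k=2):
    """Alternate updates, running k discriminator steps per generator step.

    train_d, train_g: callables that perform one update on a batch and
    return the corresponding loss (hypothetical stand-ins).
    """
    d_losses, g_losses = [], []
    for i, batch in enumerate(batches):
        # Every (k+1)-th iteration updates the generator; the others update D
        if i % (k + 1) == k:
            g_losses.append(train_g(batch))
        else:
            d_losses.append(train_d(batch))
    return d_losses, g_losses

# Usage with dummy update functions that record which model was trained
log = []

def fake_train_d(batch):
    log.append('D')
    return 0.0

def fake_train_g(batch):
    log.append('G')
    return 0.0

d_losses, g_losses = train_alternating(fake_train_d, fake_train_g, range(6))
print(''.join(log))  # DDGDDG
```

Because the two models compete, the recorded `d_losses` and `g_losses` are better read as a sign that neither network is winning outright than as a curve expected to decrease.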
Figure: generated digits from the first epoch to the 45th epoch. Over training, the generated examples look closer to real ones.
Generative models, and especially Generative Adversarial Networks, are currently among the trending areas of deep learning, and they have found their way into a few practical applications. For example, a generative model can be trained to generate the most likely next video frames by learning the features of the previous frames. Another popular example where GANs can be used is search engines that predict the next likely word before the user enters it, by studying the sequence of previously entered words.
If you found this excerpt useful, do check out more comprehensive coverage of popular deep learning topics in our book Deep Learning with Theano.
[box type="download" align="" class="" width=""] Download files [/box]