
Today there are a variety of tools available for developing and training your own reinforcement learning agent. In this tutorial, we are going to build a Keras-RL agent for the CartPole problem. We will work through this example because it won’t eat up your GPU or your cloud budget. The same logic can also be extended to harder problems, such as Atari games.

This article is an excerpt taken from the book Deep Learning Quick Reference, written by Mike Bernico.

Let’s talk quickly about the CartPole environment first:

  • CartPole: The CartPole environment consists of a pole, balanced on a cart. The agent has to learn how to balance the pole vertically, while the cart underneath it moves. The agent is given the position of the cart, the velocity of the cart, the angle of the pole, and the rotational rate of the pole as inputs. The agent can apply a force on either side of the cart. If the pole falls more than 15 degrees from vertical, it’s game over for our agent.

The CartPole agent will use a fairly modest neural network that you should be able to train quickly, even without a GPU. We will start by looking at the model architecture. Then we will define the agent’s memory and exploration policy, and finally, train the agent.

CartPole neural network architecture

Three hidden layers with 16 neurons each are more than enough to solve this simple problem. We will use the following code to define the model:

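from keras.layers import Dense, Flatten, Input
from keras.models import Model

def build_model(state_size, num_actions):
    # Keras-RL feeds observations with a leading window dimension (window_length=1)
    inputs = Input(shape=(1, state_size))
    x = Flatten()(inputs)
    # Three small hidden layers are enough for CartPole
    x = Dense(16, activation='relu')(x)
    x = Dense(16, activation='relu')(x)
    x = Dense(16, activation='relu')(x)
    # One linear output per action: the predicted Q value of that action
    outputs = Dense(num_actions, activation='linear')(x)
    model = Model(inputs=inputs, outputs=outputs)
    return model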

The input will be a 1 x state_size vector, and there will be one output neuron for each possible action, predicting the Q value of that action at each step. By taking the argmax of the outputs, we can choose the action with the highest Q value, but we don’t have to do that ourselves as Keras-RL will do it for us.
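To make the argmax step concrete, here is a minimal pure-Python sketch of greedy action selection. The function name greedy_action and the sample Q values are illustrative assumptions; Keras-RL performs this step internally, so nothing like this needs to be written in practice.

```python
def greedy_action(q_values):
    """Return the index of the largest Q value (ties go to the first index)."""
    best_action = 0
    for action, q_value in enumerate(q_values):
        if q_value > q_values[best_action]:
            best_action = action
    return best_action


# Hypothetical network output for CartPole's two actions
# (index 0 pushes the cart left, index 1 pushes it right).
example_q_values = [0.42, 0.87]
chosen = greedy_action(example_q_values)
print(chosen)
```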

Keras-RL Memory

Keras-RL provides a class called rl.memory.SequentialMemory, a fast and efficient data structure in which we can store the agent’s experiences:

from rl.memory import SequentialMemory
memory = SequentialMemory(limit=50000, window_length=1)

We need to specify a maximum size for this memory object, which is a hyperparameter. As new experiences are added to this memory and it becomes full, old experiences are forgotten.
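The forgetting behaviour can be sketched with a bounded deque from the standard library. This is only an illustration of the idea, not how SequentialMemory is implemented; the ReplayBuffer class and its method names are assumptions made for this example.

```python
import random
from collections import deque


class ReplayBuffer:
    """Fixed-size experience store; the oldest entries are dropped when full."""

    def __init__(self, limit):
        self.buffer = deque(maxlen=limit)

    def append(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size, rng=random):
        # Uniform sampling without replacement; needs batch_size <= len(self)
        return rng.sample(list(self.buffer), batch_size)

    def __len__(self):
        return len(self.buffer)


# A tiny limit makes the forgetting visible: five experiences go in, three remain.
buffer = ReplayBuffer(limit=3)
for step in range(5):
    buffer.append(state=step, action=0, reward=1.0, next_state=step + 1, done=False)

oldest_state = buffer.buffer[0][0]
print(len(buffer), oldest_state)
```

With limit=50000, as in the tutorial, the same thing happens once the 50,001st experience arrives.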

Keras-RL Policy

Keras-RL provides an ε-greedy Q policy called rl.policy.EpsGreedyQPolicy that we can use to balance exploration and exploitation. We can use rl.policy.LinearAnnealedPolicy to decay our ε as the agent steps forward in the world, as shown in the following code:

from rl.policy import EpsGreedyQPolicy, LinearAnnealedPolicy
policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps', value_max=1., value_min=.1, value_test=.05, nb_steps=10000)

Here we’re saying that we want ε to start at 1 and go no lower than 0.1 during training, and that ε should be fixed at 0.05 whenever the agent is being tested. We set the number of steps over which ε falls from 1 to 0.1 to 10,000, and Keras-RL handles the decay math for us.
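The decay math is simple enough to sketch by hand. The following is a minimal illustration of linear annealing combined with ε-greedy selection, using the same values as the policy above; the function names annealed_epsilon and epsilon_greedy are assumptions made for this example, not Keras-RL APIs.

```python
import random


def annealed_epsilon(step, value_max=1.0, value_min=0.1, nb_steps=10000):
    """Linearly decay epsilon from value_max to value_min over nb_steps steps."""
    slope = -(value_max - value_min) / nb_steps
    return max(value_min, slope * step + value_max)


def epsilon_greedy(q_values, epsilon, rng=random):
    """Pick a random action with probability epsilon, else the greedy action."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda action: q_values[action])


# Epsilon at the start, halfway through, at the end, and past the end of annealing.
for step in (0, 5000, 10000, 20000):
    print(step, round(annealed_epsilon(step), 3))
```

Early in training almost every action is random, which fills the memory with varied experiences; after 10,000 steps the agent still explores one step in ten.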


With a model, memory, and policy defined, we’re now ready to create a deep Q network agent and pass it those objects. Keras-RL provides an agent class called rl.agents.dqn.DQNAgent that we can use for this, as shown in the following code:

from keras.optimizers import Adam
from rl.agents.dqn import DQNAgent

dqn = DQNAgent(model=model, nb_actions=num_actions, memory=memory, nb_steps_warmup=10,
               target_model_update=1e-2, policy=policy)
dqn.compile(Adam(lr=1e-3), metrics=['mae'])

Two of these parameters are probably unfamiliar at this point, target_model_update and nb_steps_warmup:

  • nb_steps_warmup: Determines how long we wait before we start doing experience replay, which, if you recall, is when we actually start training the network. This lets us build up enough experience to build a proper minibatch. If you choose a value for this parameter that’s smaller than your batch size, Keras-RL will sample with replacement.
  • target_model_update: The Q function is recursive, and when the agent updates its network for Q(s,a), that update also changes the prediction it will make for Q(s’, a). This can make for a very unstable network. Most deep Q network implementations address this limitation by using a target network, which is a copy of the deep Q network that isn’t trained directly, but is instead refreshed from the trained network every so often. The target_model_update parameter controls how this happens: in Keras-RL, a value of 1 or more replaces the target network every that many steps, while a value below 1, such as the 1e-2 used here, blends a small fraction of the trained weights into the target network at every step.
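A target network can be refreshed in two ways: a hard update copies the trained weights wholesale, while a soft update nudges the target towards them by a small rate. The sketch below shows both on plain lists of numbers; the function names and sample weights are assumptions made for illustration, and the rate of 1e-2 mirrors the target_model_update value passed to the agent above.

```python
def hard_update(online_weights):
    """Replace the target weights with a fresh copy of the online weights."""
    return list(online_weights)


def soft_update(target_weights, online_weights, tau):
    """Move each target weight a fraction tau of the way to the online weight."""
    return [
        tau * online + (1.0 - tau) * target
        for target, online in zip(target_weights, online_weights)
    ]


# Hypothetical weights: the online network has moved, the target has not.
online = [1.0, -2.0]
target = [0.0, 0.0]

softly_updated = soft_update(target, online, tau=1e-2)
fully_replaced = hard_update(online)
print(softly_updated, fully_replaced)
```

Either way, the values the agent bootstraps from change more slowly than the network being trained, which is what keeps the updates stable.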

Keras-RL Training

Keras-RL provides several Keras-like callbacks that allow for convenient model checkpointing and logging. We will use both of those callbacks below. If you would like to see more of the callbacks Keras-RL provides, they can be found here: https://github.com/matthiasplappert/keras-rl/blob/master/rl/callbacks.py. That file also contains a Callback class that you can use to create your own Keras-RL callbacks.

We will use the following code to train our model:

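from rl.callbacks import FileLogger, ModelIntervalCheckpoint

def build_callbacks(env_name):
    # Checkpoint the weights every 5,000 steps; {step} is filled in by the callback
    checkpoint_weights_filename = 'dqn_' + env_name + '_weights_{step}.h5f'
    log_filename = 'dqn_{}_log.json'.format(env_name)
    callbacks = [ModelIntervalCheckpoint(checkpoint_weights_filename, interval=5000)]
    # Log training metrics to a JSON file
    callbacks += [FileLogger(log_filename, interval=100)]
    return callbacks

callbacks = build_callbacks(ENV_NAME)

# Train without rendering and pass in the callbacks built above
dqn.fit(env, nb_steps=50000, visualize=False, callbacks=callbacks)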

Once the agent’s callbacks are built, we can fit the DQNAgent by calling its .fit() method. Take note of the visualize parameter in this example. If visualize were set to True, we would be able to watch the agent interact with the environment as it trains. However, this significantly slows down the training.
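The {step} placeholder in the checkpoint filename is an ordinary Python format field that the checkpoint callback fills in with the current step count. As a rough sketch, assuming the environment name is CartPole-v0 (the tutorial does not show how ENV_NAME is defined), a 50,000-step run with interval=5000 would produce filenames like these:

```python
# Assumed environment name; the tutorial's ENV_NAME is not shown.
env_name = 'CartPole-v0'
checkpoint_template = 'dqn_' + env_name + '_weights_{step}.h5f'

# One checkpoint every 5,000 steps over a 50,000-step run.
checkpoint_files = [
    checkpoint_template.format(step=step)
    for step in range(5000, 50000 + 1, 5000)
]

for filename in checkpoint_files:
    print(filename)
```

Keeping the step in the name means earlier checkpoints are not overwritten, so you can go back to an earlier set of weights if later training degrades.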


After the first 250 episodes, we will see that the total rewards for the episode approach 200 and the episode steps also approach 200. This means that the agent has learned to balance the pole on the cart until the environment ends at a maximum of 200 steps.

It’s of course fun to watch our success, so we can use the DQNAgent .test() method to evaluate the agent for some number of episodes. The following code calls this method:

dqn.test(env, nb_episodes=5, visualize=True)

Here we’ve set visualize=True so we can watch our agent balance the pole, as shown in the following image:

[Image: balancing cartpole]

There we go, that’s one balanced pole! Alright, I know, I’ll admit that balancing a pole on a cart isn’t all that cool, but it’s a good enough demonstration of the process! Hopefully, you have now understood the dynamics behind the process, and as we discussed earlier, the solution to this problem can be applied to other similar game-based problems.

If you found this article useful, make sure you check out the book Deep Learning Quick Reference to learn about the other types of reinforcement learning models you can build using Keras.
