OpenAI’s gradient checkpointing: A package that makes huge neural nets fit into memory


OpenAI releases a Python/TensorFlow package for gradient checkpointing!

Gradient checkpointing lets you fit much larger neural nets into memory in exchange for some extra computation: for feed-forward models, the authors report fitting more than 10x larger models onto their GPU at only about a 20% increase in computation time. The tools in this package, developed jointly by Tim Salimans and Yaroslav Bulatov, help rewrite a TensorFlow model so that it uses less memory.

Computing the gradient of the loss by backpropagation is the memory-intensive part of training deep neural networks. By checkpointing nodes in the computation graph defined by your model, and recomputing the parts of the graph between those nodes during backpropagation, the gradient can be computed at a reduced memory cost. When training a deep feed-forward neural network consisting of n layers, memory consumption can be reduced to O(sqrt(n)) at the cost of performing one additional forward pass.
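To see where the O(sqrt(n)) figure comes from, here is a back-of-the-envelope derivation. It assumes an n-layer chain with a checkpoint every k layers (k dividing n), counts memory in stored activations, and ignores constant factors; it is a rough cost model, not the package's own analysis.

```latex
% Memory: n/k stored checkpoints plus at most k recomputed activations for the segment being processed
M(k) \approx \frac{n}{k} + k,
\qquad
\frac{dM}{dk} = -\frac{n}{k^{2}} + 1 = 0 \;\Rightarrow\; k = \sqrt{n},
\qquad
M(\sqrt{n}) \approx 2\sqrt{n} = O(\sqrt{n})

% Compute: one normal forward pass plus one recomputation of every non-checkpoint node
E(k) = n + \left(n - \frac{n}{k}\right) < 2n
```

So the memory saving costs fewer than n extra forward-node evaluations, i.e. at most one additional forward pass.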

The graph shows the amount of memory used while training the official TensorFlow CIFAR10 ResNet example with the regular tf.gradients function and with the optimized gradient function.

To see how this works, consider a simple feed-forward neural network.

In the figure above:

f: the activations of the neural network layers

b: the gradient of the loss with respect to the activations and parameters of these layers

All these nodes are evaluated in order during the forward pass and in reverse order during the backward pass. The results computed for the f nodes are needed to compute the b nodes. Hence, after the forward pass, all f nodes are kept in memory, and each can be erased only once backpropagation has progressed far enough to have computed all dependencies, or children, of that f node. This means that in simple backpropagation, the memory required grows linearly with the number of neural net layers n.
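The bookkeeping just described can be simulated in a few lines of Python. This is a toy model, not code from the package: nodes are plain integers, `vanilla_backprop` is a hypothetical name, and it assumes b_i depends only on f_i (a real layer's backward step also needs its input).

```python
def vanilla_backprop(n):
    """Simulate vanilla backprop on an n-layer chain.

    Simplifying assumption: b_i needs only f_i.
    Returns (peak number of stored f nodes, number of f evaluations)."""
    stored, evals, peak = set(), 0, 0
    # Forward pass: compute f_1 .. f_n and keep every one of them.
    for i in range(1, n + 1):
        stored.add(i)
        evals += 1
        peak = max(peak, len(stored))
    # Backward pass: b_n .. b_1; f_i is freed as soon as b_i has used it.
    for i in range(n, 0, -1):
        assert i in stored, "f_i must still be in memory"
        stored.discard(i)
    return peak, evals


print(vanilla_backprop(8))  # (8, 8)
```

All eight activations are alive at the end of the forward pass, so peak memory equals n, while each node is evaluated exactly once.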

Graph 1: Vanilla Backpropagation

The graph above shows simple vanilla backpropagation, which computes each node exactly once. If we are willing to recompute nodes, however, we can save a lot of memory. For instance, we could simply recompute every node from the forward pass each time it is needed.

The order of execution, and the memory used, then appear as follows:

Graph 2: Backpropagation with poor memory

With this strategy, the memory required to compute the gradients in our graph is constant in the number of neural network layers n, which is optimal in terms of memory. However, the number of node evaluations now scales as n^2 instead of n: each of the n nodes is recomputed on the order of n times. The computation therefore becomes much slower for deep networks, which makes this method impractical for deep learning.
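The same toy model shows the trade-off of recomputing everything: memory stays at a single activation while the number of evaluations during the backward pass grows as n(n+1)/2. As before, the function name is hypothetical and b_i is assumed to need only f_i.

```python
def recompute_everything(n):
    """Backprop on an n-layer chain that stores no activations at all.

    Returns (peak live f nodes, f evaluations made during the backward pass).
    The initial forward pass that produced the loss is not counted."""
    peak, evals = 0, 0
    for i in range(n, 0, -1):          # compute b_n, b_{n-1}, ..., b_1
        live = []                      # activations currently held in memory
        for j in range(1, i + 1):      # rebuild f_1 .. f_i from the input
            live = [j]                 # f_j replaces f_{j-1}, which is freed
            evals += 1
            peak = max(peak, len(live))
        # b_i consumes f_i (live[0]) here; it is then dropped
    return peak, evals


print(recompute_everything(100))  # (1, 5050)
```

For 100 layers that is 5050 evaluations against 100 for vanilla backprop, which is the quadratic blow-up described above.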

To strike a balance between memory and computation, the package uses a strategy that allows nodes to be recomputed, but not too often: a subset of the neural net activations is marked as checkpoint nodes.

Graph with a chosen checkpoint node

These checkpoint nodes are kept in memory after the forward pass, while the remaining nodes are recomputed at most once. After recomputation, the non-checkpoint nodes are kept in memory until they are no longer required. In a simple feed-forward neural net, all neuron activation nodes are graph separators, or articulation points, of the graph defined by the forward pass.

This means that when computing a b node during backprop, only the nodes between that b node and the last checkpoint preceding it need to be recomputed. Once backprop has progressed far enough to reach the checkpoint node, all nodes that were recomputed from it can be erased from memory. The order of computation and memory usage then appear as follows:

Graph 3: Checkpointed Backpropagation
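Extending the toy model to checkpoints placed every `every` layers gives a sketch of checkpointed backprop on a chain. The function name and the uniform spacing are illustrative assumptions, and b_i is again assumed to need only f_i; the package itself rewrites a TensorFlow graph rather than simulating one.

```python
def checkpointed_backprop(n, every):
    """Simulate checkpointed backprop on an n-layer chain.

    Every `every`-th activation is a checkpoint. Simplifying assumption:
    b_i needs only f_i. Returns (peak stored f nodes, total f evaluations)."""
    checkpoints = set(range(every, n + 1, every))
    stored, evals = set(), 0
    # Forward pass: evaluate all nodes, keep only checkpoints and f_n (needed first by backprop).
    for i in range(1, n + 1):
        evals += 1
        if i in checkpoints or i == n:
            stored.add(i)
    peak = len(stored)
    # Backward pass, one segment at a time, starting from the last layer.
    hi = n
    while hi >= 1:
        lo = (hi - 1) // every * every  # nearest checkpoint before f_hi (0 means the input)
        for j in range(lo + 1, hi + 1):  # recompute the segment starting after f_lo
            if j not in stored:
                stored.add(j)
                evals += 1
        peak = max(peak, len(stored))
        for j in range(hi, lo, -1):  # compute b_hi .. b_{lo+1}, freeing each f_j after use
            stored.discard(j)
        hi = lo
    return peak, evals


print(checkpointed_backprop(16, 4))  # (7, 28)
```

With n = 16 and a checkpoint every 4 layers, peak memory drops from 16 activations to n/k + k - 1 = 7, and the 12 recomputed non-checkpoint nodes cost less than one extra forward pass.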

The package implements checkpointed backprop by taking the graph for standard (vanilla) backprop (Graph 1) and automatically rewriting it using the TensorFlow graph editor. For graphs that contain articulation points (single-node graph separators), checkpoints are selected automatically using the sqrt(n) strategy, giving sqrt(n) memory usage for feed-forward networks. For more general graphs that only contain multi-node graph separators, checkpointed backprop still works, but the checkpoints currently have to be selected manually by the user.
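Below is a simplified sketch of automatic selection on a graph given in topological order: a node counts as a single-node separator when no edge jumps over it, and roughly every sqrt(n)-th separator is kept as a checkpoint. It illustrates the idea only and is not the package's heuristic; the function names and the edge-list representation are assumptions.

```python
import math


def single_node_separators(order, edges):
    """Return the interior nodes that no edge jumps over.

    `order` is a topological order of the forward graph and `edges` a list of
    (src, dst) pairs. Removing such a node splits the graph into an earlier
    and a later part."""
    pos = {v: i for i, v in enumerate(order)}
    # furthest[i]: furthest position reached by an edge leaving the node at position i
    furthest = list(range(len(order)))
    for src, dst in edges:
        furthest[pos[src]] = max(furthest[pos[src]], pos[dst])
    result = []
    reach = 0  # furthest position reached by edges leaving any earlier node
    for i, v in enumerate(order):
        if 0 < i < len(order) - 1 and reach <= i:
            result.append(v)
        reach = max(reach, furthest[i])
    return result


def sqrt_n_checkpoints(order, edges):
    """Keep roughly every sqrt(n)-th single-node separator as a checkpoint."""
    seps = single_node_separators(order, edges)
    if not seps:
        return []  # only multi-node separators: checkpoints must be picked manually
    step = math.isqrt(len(seps))
    return seps[step - 1::step]


chain = list(range(18))
print(sqrt_n_checkpoints(chain, [(i, i + 1) for i in range(17)]))  # [4, 8, 12, 16]
```

On a plain chain every interior node is a separator. A skip connection such as ('f1', 'f3') removes 'f2' from the candidates, and a graph with no single-node separators gets no automatic checkpoints, which corresponds to the manual case.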

Summing up, the biggest advantage of gradient checkpointing is that it can save a lot of memory for large neural network models. However, the package also has some limitations, listed below.

Limitations of gradient checkpointing:

  • The provided code does all graph manipulation in Python before running your model, which is slow for large graphs.
  • The current algorithm for automatically selecting checkpoints is purely heuristic and is expected to fail on some models outside the class that has been tested. In such cases, manual checkpoint selection is preferable.

For a more detailed explanation of gradient checkpointing, computation graphs, memory usage, and gradient computation strategies, see Yaroslav Bulatov's Medium post on gradient checkpointing.

