With its updated version of Autograd, JAX can automatically differentiate native Python and NumPy code. It handles a large subset of Python features, including loops, branches, recursion, and closures. It supports both reverse-mode differentiation (backpropagation) and forward-mode differentiation, and the two can be composed arbitrarily, in any order.
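As a rough illustration of the paragraph above, the sketch below differentiates a function that contains an ordinary Python loop and an `if` branch on the input's value, then composes forward mode over reverse mode to get a second derivative. The function `f` is a made-up example; `jax.grad` and `jax.jacfwd` are the standard JAX transformations. Without `jit`, JAX's autodiff works on concrete values, which is why the Python branch is allowed here.

```python
import jax

def f(x):
    # Hypothetical example function.
    # A Python loop: builds x**3 by repeated multiplication.
    y = 1.0
    for _ in range(3):
        y = y * x
    # A Python branch on the value of x: f(x) = x**3 for x > 0, else -x**3.
    if x > 0:
        return y
    return -y

df = jax.grad(f)               # reverse mode (backpropagation)
d2f = jax.jacfwd(jax.grad(f))  # forward mode composed over reverse mode

print(df(2.0))   # first derivative at 2.0: 3 * 2**2
print(d2f(2.0))  # second derivative at 2.0: 6 * 2
print(df(-1.0))  # takes the other branch: -3 * (-1)**2
```

Because each transformation returns an ordinary Python function, `jax.grad(jax.jacfwd(f))` or `jax.grad(jax.grad(f))` would work just as well; the choice of mode mainly affects efficiency for functions with many inputs or many outputs.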
XLA (Accelerated Linear Algebra) is a domain-specific linear algebra compiler originally built to optimize TensorFlow computations. JAX uses XLA to compile and run NumPy programs on GPUs and TPUs. By default, library calls are compiled and executed just-in-time. JAX also lets you compile your own Python functions just-in-time into XLA-optimized kernels through a one-function API, jit.
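A minimal sketch of the one-function `jit` API follows. The `selu` activation and its constants are only an illustrative workload; `jax.jit` is the real JAX entry point, and `block_until_ready()` waits for the asynchronously dispatched result.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    # Illustrative element-wise function written with jax.numpy.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# Wrap the Python function; it is compiled by XLA on the first call
# and the compiled kernel is reused for later calls with the same shapes/dtypes.
selu_jit = jax.jit(selu)

x = jnp.arange(5.0)
y = selu_jit(x).block_until_ready()
print(y)
```

The jitted and plain versions compute the same values; the difference is that the jitted one runs as a single compiled XLA computation instead of dispatching each operation separately.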
How does JAX work?
At its core, JAX specializes and translates high-level Python and NumPy functions into an intermediate representation that can be transformed and then lifted back into a Python function. It traces a Python function by monitoring the primitive operations applied to its inputs to produce its outputs, and records these operations, along with the data flow between them, in a directed acyclic graph (DAG).
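One way to look at this traced representation is `jax.make_jaxpr`, which runs a function on abstract inputs and returns the recorded program (a "jaxpr"). The function `f` below is a hypothetical example; the exact printed text can differ between JAX versions, so the sketch only inspects which primitives were recorded.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    # Hypothetical example: three primitive operations.
    return jnp.sin(x) * y + 2.0

closed_jaxpr = jax.make_jaxpr(f)(1.0, 3.0)
print(closed_jaxpr)  # human-readable listing of the traced program

# Each equation records one primitive together with its inputs and outputs.
names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(names)
```

The recorded equations reference each other's output variables, which is the data-flow graph the paragraph describes.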
To trace a function, JAX wraps its primitive operations; when one of them is called, it adds itself to a list of performed operations together with its inputs and outputs. To keep track of the data flow between these primitives, the values being tracked are wrapped in instances of the Tracer class.
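The Tracer wrapping can be observed directly by recording the type of the argument a function receives while it is being traced. The list `seen` and the function `f` are hypothetical; the concrete Tracer subclass name (for example something like `DynamicJaxprTracer`) is an internal detail that may change between JAX versions, so the sketch only checks that the name contains "Tracer".

```python
import jax
import jax.numpy as jnp

seen = []  # hypothetical log of argument types observed during tracing

def f(x):
    # While JAX traces f, x is a Tracer standing in for the real value.
    seen.append(type(x).__name__)
    return jnp.cos(x) * x

g = jax.jit(f)
g(2.0)  # first call: f is traced once, then compiled
g(3.0)  # same shape and dtype: the cached compilation is reused, no retrace

print(seen)
```

The Python body of `f` runs only during tracing, which is why `seen` grows once even though `g` is called twice.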
The team is working to expand the project with support for Cloud TPU, multi-GPU, and multi-TPU setups. Future plans include full NumPy coverage, some SciPy coverage, and more. Because JAX is still a research project, it may contain bugs and is not recommended for production use.
To read more and contribute to the project, head over to GitHub.