Automatic Differentiation in JAX

JAX is a Python library for high-performance numerical and scientific computing that runs on modern hardware such as GPUs and TPUs. It provides a NumPy-like API (jax.numpy) and supports automatic differentiation, enabling efficient computation of gradients, which are central to optimization problems and machine learning algorithms.

Automatic Differentiation (Autodiff)

Automatic differentiation is a set of techniques for numerically evaluating the derivative of a function specified by a computer program. Autodiff is neither symbolic differentiation nor numerical differentiation (finite differences). It takes advantage of the fact that any computer program, no matter how complicated, executes a sequence of elementary arithmetic operations (additions, multiplications, etc.) and elementary functions (exp, log, sin, etc.). By applying the chain rule repeatedly to these operations, derivatives of arbitrary order can be computed automatically, efficiently, and accurately to working precision, without the truncation error that finite differences introduce.
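A hand-traced sketch of the idea, in plain Python rather than JAX: the function f(x) = sin(x)cos(x) is evaluated one elementary operation at a time, and each intermediate value carries its own derivative, combined by the chain and product rules. The helper names `f_with_derivative` and `finite_difference` are hypothetical, and the derivative rules for sin, cos, and multiplication are written out by hand here, whereas a real autodiff system attaches them to each operation automatically.

```python
import math

def f_with_derivative(x):
    """Evaluate f(x) = sin(x) * cos(x) one elementary operation at a time,
    carrying each intermediate value together with its derivative d/dx."""
    v1, d1 = math.sin(x), math.cos(x)      # v1 = sin(x),  dv1/dx = cos(x)
    v2, d2 = math.cos(x), -math.sin(x)     # v2 = cos(x),  dv2/dx = -sin(x)
    v3, d3 = v1 * v2, d1 * v2 + v1 * d2    # v3 = v1 * v2, product rule
    return v3, d3

def finite_difference(fn, x, h=1e-6):
    # Forward-difference approximation, included only for comparison.
    return (fn(x + h) - fn(x)) / h

x = 1.0
_, chain_rule_value = f_with_derivative(x)
fd_value = finite_difference(lambda t: math.sin(t) * math.cos(t), x)
exact = math.cos(2 * x)  # d/dx [sin(x)cos(x)] = cos(2x)

print(abs(chain_rule_value - exact))  # floating-point rounding error only
print(abs(fd_value - exact))          # dominated by truncation error from the step h
```

The chain-rule value agrees with the exact derivative up to rounding, while the finite-difference estimate is off by roughly the size of the step h.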

Forward and Reverse Mode Autodiff

Autodiff comes in two main flavors:

  • Forward mode differentiation: Propagates derivatives forward, alongside the evaluation of the function, and yields the derivatives of all outputs with respect to one input (or one input direction) per pass. It is efficient for functions with a small number of inputs and a large number of outputs.
  • Reverse mode differentiation: First evaluates the function, then propagates derivatives backward from one output, yielding the derivatives of that output with respect to all inputs in a single pass. It is efficient for functions with a large number of inputs and a small number of outputs. Reverse mode is the basis for the backpropagation algorithm used to train neural networks.
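Forward mode is commonly explained with dual numbers: every value carries a tangent (its derivative along the chosen input direction), and each elementary operation updates both. The minimal sketch below is plain Python, not JAX's actual implementation, and supports only `+`, `*`, `sin`, and `cos`; the names `Dual` and `jacobian_forward` are hypothetical. Each pass seeds one input's tangent with 1, so a function with n inputs needs n passes, each producing one column of the Jacobian.

```python
import math
from dataclasses import dataclass

@dataclass
class Dual:
    """A value paired with its tangent (derivative along one input direction)."""
    val: float
    dot: float

    def __add__(self, other):
        other = _lift(other)
        return Dual(self.val + other.val, self.dot + other.dot)

    __radd__ = __add__

    def __mul__(self, other):
        other = _lift(other)
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.val * other.val, self.dot * other.val + self.val * other.dot)

    __rmul__ = __mul__

def _lift(x):
    # Treat plain numbers as constants (tangent 0).
    return x if isinstance(x, Dual) else Dual(float(x), 0.0)

def sin(x):
    return Dual(math.sin(x.val), math.cos(x.val) * x.dot)

def cos(x):
    return Dual(math.cos(x.val), -math.sin(x.val) * x.dot)

def g(x, y):
    # Hypothetical example function with two inputs and three outputs.
    return [x * y, sin(x), x + y]

def jacobian_forward(fn, inputs):
    columns = []
    for i in range(len(inputs)):
        # One forward pass per input: seed tangent 1 for input i, 0 elsewhere.
        duals = [Dual(v, 1.0 if j == i else 0.0) for j, v in enumerate(inputs)]
        columns.append([out.dot for out in fn(*duals)])
    # Transpose so that J[k][i] = d(output k) / d(input i).
    return [list(row) for row in zip(*columns)]

J = jacobian_forward(g, [2.0, 3.0])
print(J)  # rows = outputs, columns = inputs
```

Here two passes suffice for all three outputs, which is why forward mode suits functions with few inputs and many outputs.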

Differentiation Graph

A differentiation graph (often called a computational graph) represents the evaluation of a function as a directed acyclic graph: its nodes are the inputs, intermediate variables, and outputs, and its edges record which elementary operation produced each value from which operands. Automatic differentiation algorithms traverse this graph to apply the chain rule, forward from the inputs in forward mode or backward from the outputs in reverse mode.
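A tiny reverse-mode engine makes the graph concrete. In this hypothetical minimal sketch (the names `Node` and `backward` are illustrative, and only `+`, `*`, `sin`, and `cos` are supported), each node stores its value plus edges to its operands labeled with local derivatives; the backward pass visits nodes from the output toward the inputs and accumulates gradients with the chain rule.

```python
import math

class Node:
    """A graph node: a value plus edges to the nodes it was computed from."""

    def __init__(self, value, parents=()):
        self.value = value
        # Each entry is (operand_node, local derivative d(self)/d(operand)).
        self.parents = parents
        self.grad = 0.0  # accumulates d(output)/d(self); build a fresh graph per evaluation

    def __add__(self, other):
        return Node(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Node(self.value * other.value, ((self, other.value), (other, self.value)))

def sin(x):
    return Node(math.sin(x.value), ((x, math.cos(x.value)),))

def cos(x):
    return Node(math.cos(x.value), ((x, -math.sin(x.value)),))

def backward(output):
    # Depth-first post-order puts every node after its operands.
    order, seen = [], set()

    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent, _ in node.parents:
            visit(parent)
        order.append(node)

    visit(output)
    output.grad = 1.0
    # Reversed order: each node's gradient is complete before it is pushed to its operands.
    for node in reversed(order):
        for parent, local in node.parents:
            parent.grad += node.grad * local

x = Node(1.0)
y = sin(x) * cos(x)
backward(y)
print(x.grad)  # x feeds both sin and cos, so both paths are summed

a, b = Node(2.0), Node(3.0)
z = a * b + sin(a)
backward(z)
print(a.grad, b.grad)  # dz/da = b + cos(a), dz/db = a
```

A single backward pass from `z` yields the gradient with respect to every input at once, which is the property that makes reverse mode attractive when inputs outnumber outputs.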

How JAX Implements Autodiff

JAX implements automatic differentiation in both forward and reverse mode, which lets it compute gradients, Jacobians, and Hessians efficiently. Its main entry points are:

  • jax.grad: Computes the gradient of a scalar-valued function. It uses reverse-mode differentiation, which is efficient when the input dimension is much larger than the output dimension.
  • jax.jvp and jax.vjp: Lower-level building blocks that compute a Jacobian-vector product (forward mode) and a vector-Jacobian product (reverse mode), respectively. jax.grad is built on top of jax.vjp.
  • jax.jacfwd, jax.jacrev, and jax.hessian: Compute full Jacobian and Hessian matrices by composing these forward- and reverse-mode products.
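To see what the two products compute, here is a sketch on a small hypothetical function `g` mapping 3 inputs to 2 outputs (`g`, `v`, and `u` are illustrative, not from the original). `jax.jvp` pushes a tangent vector `v` from input space through `g` and returns J·v; `jax.vjp` returns a function that pulls a cotangent `u` from output space back and returns uᵀ·J. The full Jacobian from `jax.jacfwd` is computed only to check both results.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical example: 3 inputs -> 2 outputs.
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])

# Forward mode: one JVP gives J(x) @ v for a tangent v in input space.
v = jnp.array([1.0, 0.0, 0.0])
y, jv = jax.jvp(g, (x,), (v,))

# Reverse mode: vjp returns g(x) and a function mapping cotangents u to u @ J(x).
y_again, g_vjp = jax.vjp(g, x)
u = jnp.array([1.0, 0.0])
(uj,) = g_vjp(u)  # one entry per primal input; g has a single input x

# Full Jacobian, used here only to check the two products.
J = jax.jacfwd(g)(x)
print(jv, J @ v)
print(uj, u @ J)
```

jax.jacfwd assembles the Jacobian from JVPs, one per input dimension, while jax.jacrev uses VJPs, one per output dimension, mirroring the input/output trade-off between the two modes.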

Example Usage

Here's a simple example of how to use JAX to compute the gradient of a function:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * jnp.cos(x)

grad_f = jax.grad(f)  # a new function that returns df/dx
print(grad_f(1.0))    # gradient of f at x = 1.0
```

This code snippet computes the derivative of $f(x) = \sin(x)\cos(x)$ at the point $x = 1.0$. Since $f'(x) = \cos^2(x) - \sin^2(x) = \cos(2x)$, the printed value should be close to $\cos(2) \approx -0.416$.
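Because jax.grad returns an ordinary Python function, it can be applied again to obtain higher-order derivatives, which is one way to get the Hessians mentioned above. A short sketch reusing the same f; the names `df` and `d2f` are illustrative, and the analytic forms in the comments follow from f(x) = sin(x)cos(x) = sin(2x)/2.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * jnp.cos(x)

df = jax.grad(f)              # f'(x)  = cos(2x)
d2f = jax.grad(jax.grad(f))   # f''(x) = -2 sin(2x)

x = 1.0
print(df(x), jnp.cos(2 * x))
print(d2f(x), -2 * jnp.sin(2 * x))

# For a scalar function of a scalar, jax.hessian yields the same second derivative.
print(jax.hessian(f)(x))
```

Nesting jax.grad keeps using reverse mode at each level; jax.hessian instead composes forward mode over reverse mode, which is usually the cheaper combination for full Hessians of functions with many inputs.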