Automatic Differentiation
Automatic Differentiation in JAX
JAX is a library designed for high-performance numerical and scientific computing on modern hardware such as GPUs and TPUs. It extends NumPy and enables automatic differentiation, allowing for the efficient computation of gradients, which are crucial for optimization problems and machine learning algorithms.
Automatic Differentiation (Autodiff)
Automatic differentiation is a set of techniques to numerically evaluate the derivative of a function specified by a computer program. Autodiff is neither symbolic differentiation nor numerical differentiation (finite differences). It takes advantage of the fact that any computer program, no matter how complicated, executes a sequence of elementary arithmetic operations (additions, multiplications, etc.) and elementary functions (exp, log, sin, etc.). By applying the chain rule repeatedly to these operations, derivatives of any order can be computed automatically, accurately, and efficiently.
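To make this concrete, here is a minimal sketch of the chain-rule bookkeeping using a toy dual-number class in plain Python. The Dual class and the sin/cos wrappers are illustrative only and are not part of any library; each operation carries a value together with its derivative and applies the corresponding chain-rule step.

import math

class Dual:
    # A value paired with its derivative with respect to the input.
    def __init__(self, val, dot):
        self.val = val
        self.dot = dot

    def __mul__(self, other):
        # Product rule: (u*v)' = u'*v + u*v'
        return Dual(self.val * other.val,
                    self.dot * other.val + self.val * other.dot)

def sin(x):
    # Chain rule: (sin u)' = cos(u) * u'
    return Dual(math.sin(x.val), math.cos(x.val) * x.dot)

def cos(x):
    # Chain rule: (cos u)' = -sin(u) * u'
    return Dual(math.cos(x.val), -math.sin(x.val) * x.dot)

x = Dual(1.0, 1.0)        # seed the derivative dx/dx = 1
y = sin(x) * cos(x)
print(y.val, y.dot)        # value of sin(x)cos(x) and its derivative at x = 1.0

Every elementary operation updates both the value and the derivative, so the derivative of the whole program falls out of the same pass that computes its result.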
Forward and Reverse Mode Autodiff
Autodiff comes in two main flavors:
- Forward mode differentiation: Computes derivatives with respect to one input variable at a time. It is efficient for functions with a small number of inputs and a large number of outputs.
- Reverse mode differentiation: Computes the derivatives of one output variable at a time, with respect to all inputs. It is efficient for functions with a large number of inputs and a small number of outputs. Reverse mode is the basis for the backpropagation algorithm used in training neural networks. (A JAX comparison of the two modes follows this list.)
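The two modes can be compared directly in JAX, which exposes them as jax.jacfwd and jax.jacrev. The sketch below uses a small made-up function g with three inputs and two outputs; both transformations return the same 2x3 Jacobian, but forward mode builds it one input (column) at a time while reverse mode builds it one output (row) at a time.

import jax
import jax.numpy as jnp

def g(x):
    # Illustrative function: 3 inputs, 2 outputs.
    return jnp.array([jnp.sum(x ** 2), jnp.prod(x)])

x = jnp.array([1.0, 2.0, 3.0])

J_fwd = jax.jacfwd(g)(x)   # forward mode: column by column
J_rev = jax.jacrev(g)(x)   # reverse mode: row by row

print(J_fwd)
print(J_rev)               # both print the same 2x3 Jacobian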
Differentiation Graph
A differentiation graph (often called a computation graph) is a representation of a function as the sequence of elementary operations used to compute it, together with the intermediate variables and their dependencies. Automatic differentiation algorithms traverse this graph to apply the chain rule efficiently and accurately.
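JAX records such a graph implicitly by tracing a function into an intermediate representation called a jaxpr, which lists the primitive operations and the intermediate variables they produce. One way to inspect this trace, shown here with the same sin-cos function used in the example below, is jax.make_jaxpr:

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * jnp.cos(x)

# Prints the traced sequence of primitives (sin, cos, mul) and their
# intermediate variables, which is the graph autodiff operates on.
print(jax.make_jaxpr(f)(1.0))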
How JAX Implements Autodiff
JAX implements automatic differentiation in both forward and reverse mode, enabling the efficient calculation of gradients, Jacobians, and Hessians. Its core differentiation functions are:
- jax.grad: Computes the gradient of a scalar-valued function. It employs reverse-mode differentiation, which is particularly useful for functions where the input dimension is much larger than the output dimension.
- jax.jvp and jax.vjp: Compute the Jacobian-vector product (forward mode) and the vector-Jacobian product (reverse mode), respectively.
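A brief sketch of how jax.jvp and jax.vjp are called, reusing the sin-cos function from the example below; the seed values of 1.0 are arbitrary choices for illustration.

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * jnp.cos(x)

# Forward mode: push a tangent vector through f (Jacobian-vector product).
primal_out, tangent_out = jax.jvp(f, (1.0,), (1.0,))

# Reverse mode: pull a cotangent vector back through f (vector-Jacobian product).
primal_out, f_vjp = jax.vjp(f, 1.0)
(cotangent_in,) = f_vjp(1.0)

# For a scalar function of a scalar input, both recover the same derivative.
print(tangent_out, cotangent_in)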
Example Usage
Here's a simple example of how to use JAX to compute the gradient of a function:
import jax
import jax.numpy as jnp
def f(x):
    return jnp.sin(x) * jnp.cos(x)
grad_f = jax.grad(f)
print(grad_f(1.0)) # Computes the gradient of f at x = 1.0
This code snippet computes the derivative of f(x) = sin(x)cos(x) at the point x = 1.0.
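As a quick sanity check, the analytic derivative of sin(x)cos(x) is cos²(x) − sin²(x) = cos(2x), so the printed value can be compared against the closed form:

import jax.numpy as jnp
print(jnp.cos(2.0 * 1.0))  # analytic derivative cos(2x) at x = 1.0; matches grad_f(1.0)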