jax.grad
Oluwatobi Adefami / June 2022
grad
JAX provides an API, jax.grad, for taking gradients of functions in a way that closely mirrors vector calculus. Given a function $f(x)$, $\nabla f$ denotes the function that computes $f$’s gradient, i.e.
\[\nabla f = \left(\frac{\partial f}{\partial x_1}, \; \ldots, \; \frac{\partial f}{\partial x_n}\right)\]
jax.grad is a popular transformation for computing the gradient of a given function: it takes a numerical function written in Python and returns a new Python function that computes the gradient of the original function.
Using jax.grad(f) in the same manner as the equation above yields a function that computes the gradient of $f$; hence jax.grad(f)(x) computes the gradient of $f$ at the point $x$.
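For instance (a minimal sketch, not part of the original example), differentiating $f(x) = x^2$ at $x = 3.0$ gives $2x = 6.0$:
import jax

# d/dx of x**2 is 2*x, so the gradient at 3.0 should be 6.0
df = jax.grad(lambda x: x**2)
print(df(3.0))  # 6.0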
Example: let us define a function that takes an array and returns the sum of its squares.
import jax
import jax.numpy as jnp
def sum_of_squares(arr):
    return jnp.sum(arr**2)
Applying jax.grad to sum_of_squares transforms it into a new function that returns the gradient of sum_of_squares with respect to its parameter arr.
sum_of_squares_dx = jax.grad(sum_of_squares)
x = jnp.asarray([1.0, 2.0, 3.0, 4.0])
print(f"Sum of squares: {sum_of_squares(x)}")
print(f"Gradient of sum of squares: {sum_of_squares_dx(x)}")
#output
>>> Sum of squares: 30.0
Gradient of sum of squares: [2. 4. 6. 8.]
Math check (grad only):
\[x = [1.0, \; 2.0, \; 3.0, \; 4.0]\]
\[f(x) = \sum_i x_i^2\]
\[\nabla f(x) = 2x\]
\[= 2\,[1.0, \; 2.0, \; 3.0, \; 4.0]\]
\[= [2.0, \; 4.0, \; 6.0, \; 8.0]\]
We can see that the JAX API lets us perform computations in a very math-like way by allowing us to work directly with functions, in contrast to other automatic differentiation frameworks such as PyTorch and TensorFlow, where gradients are computed through a loss tensor. This way of working makes it easy to manipulate parameters mathematically.
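As a quick numerical check (a small sketch reusing sum_of_squares_dx and x from above), the gradient returned by jax.grad matches $2x$ elementwise:
# confirm the math check numerically: the gradient of sum(x**2) is 2*x
print(jnp.allclose(sum_of_squares_dx(x), 2 * x))  # True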
Note: jax.grad works in a similar way to the $\nabla$ operator from vector calculus in that it only operates on functions with a scalar output (a function whose output is a single real number). Applying it to a function with a non-scalar output raises an error.
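For example (a hedged sketch; squares is a hypothetical helper and the exact error message may vary across JAX versions), a function with an array output cannot be passed through jax.grad:
# squares returns an array, not a scalar, so jax.grad rejects it when called
def squares(arr):
    return arr**2

try:
    jax.grad(squares)(x)
except TypeError as e:
    print(e)  # e.g. "Gradient only defined for scalar-output functions. ..."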
jax.grad also lets us differentiate with respect to any argument of our choice in a very straightforward manner, like so:
# given 2 variables x, y
def sum_squared_errors(x, y):
    return jnp.sum((x - y)**2)
# transform the above function using jax.grad
sum_squared_errors_dx = jax.grad(sum_squared_errors)
y = jnp.asarray([1.1, 2.1, 3.1, 4.1])
print(f"Gradient of sum of squared errors: {sum_squared_errors_dx(x, y)}")
#output
>>> Gradient of sum of squared errors: [-0.20000005 -0.19999981 -0.19999981 -0.19999981]
Math check:
\[x = [1.0, \; 2.0, \; 3.0, \; 4.0]\]
\[y = [1.1, \; 2.1, \; 3.1, \; 4.1]\]
\[f(x, y) = \sum_i (x_i - y_i)^2\]
\[\frac{\partial f}{\partial x} = 2(x - y)\]
\[= 2\,[-0.1, \; -0.1, \; -0.1, \; -0.1]\]
\[= [-0.2, \; -0.2, \; -0.2, \; -0.2]\]
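As before, a small numerical check (a sketch reusing the function and arrays defined above) confirms the arithmetic:
# the gradient of sum((x - y)**2) with respect to x is 2*(x - y)
print(jnp.allclose(sum_squared_errors_dx(x, y), 2 * (x - y)))  # True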
jax.grad takes an argnums argument that allows us to obtain the gradient of a function with respect to one or more of its arguments; when argnums is a tuple, it returns a tuple of gradients.
jax.grad(sum_squared_errors, argnums=(0, 1))(x, y) # Find gradient wrt both x & y
# output
>>> (DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),
DeviceArray([0.20000005, 0.19999981, 0.19999981, 0.19999981], dtype=float32))
Using argnums this way simplifies a computation that would otherwise require taking the gradient of the function with respect to each variable explicitly, as illustrated below:
# take the gradient wrt each variable
grad_x = jax.grad(sum_squared_errors, argnums=0)  # gradient wrt x
grad_y = jax.grad(sum_squared_errors, argnums=1)  # gradient wrt y
We can test that the gradients taken with respect to x and y individually are equal to the gradients obtained in a single call using argnums=(0, 1).
# take the gradient wrt x and y using argnums
grad_xy = jax.grad(sum_squared_errors, argnums=(0, 1))  # gradient wrt x, y
for i in range(len(grad_xy(x, y))):
    print(grad_xy(x, y)[i] == (grad_x(x, y), grad_y(x, y))[i])
#output
>>> [ True True True True]
[ True True True True]
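An equivalent, slightly more idiomatic check (a sketch assuming the definitions above) compares each entry of the argnums=(0, 1) tuple against the individual gradients with jnp.allclose:
# each gradient in the tuple should match its individually computed counterpart
for combined, separate in zip(grad_xy(x, y), (grad_x(x, y), grad_y(x, y))):
    print(jnp.allclose(combined, separate))  # True, then True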