Differentiation using JAX#

Among other things, JAX is a powerful tool for computing derivatives of native Python and NumPy code. Awkward Array implements support for jax.grad(), jax.jvp(), and jax.vjp() for computing gradients and forward/reverse-mode Jacobian-vector/vector-Jacobian products of functions that operate on Awkward Arrays. Only a subset of Awkward Array operations can be differentiated through, including:

  • ufunc operations like x + y

  • reducers like ak.sum()

  • slices like x[1:]

How to differentiate Awkward Arrays?#

For this notebook (which is evaluated on a CPU), we need to configure JAX to use only the CPU.

import jax
jax.config.update("jax_platform_name", "cpu")

Next, we must call ak.jax.register_and_check() to register Awkward’s JAX integration.

import awkward as ak
ak.jax.register_and_check()

Let’s define a simple function that accepts an Awkward Array.

def reverse_sum(array):
    return ak.sum(array[::-1], axis=0)

We can then create an array with which to evaluate reverse_sum. The backend argument ensures that we build an Awkward Array backed by jax.Array (jaxlib.xla_extension.ArrayImpl) buffers, which power JAX’s automatic differentiation and JIT compilation features. Note, however, that Awkward Array’s JAX backend does not support JIT compilation of reducers, because XLA requires array sizes to be known at compile time, independent of the data values.

array = ak.Array([[1.0, 2.0, 3.0], [], [4.0, 5.0]], backend="jax")
reverse_sum(array)
[5.0,
 7.0,
 3.0]
-----------------
backend: jax
nbytes: 24 B
type: 3 * float64

Computing the JVP of reverse_sum requires a tangent vector, which can also be defined as an Awkward Array:

tangent = ak.Array([[0.0, 0.0, 0.0], [], [0.0, 1.0]], backend="jax")
value_jvp, jvp_grad = jax.jvp(reverse_sum, (array,), (tangent,))

jax.jvp() returns both the value of reverse_sum evaluated at array:

value_jvp
[5.0,
 7.0,
 3.0]
-----------------
backend: jax
nbytes: 24 B
type: 3 * float64
assert value_jvp.to_list() == reverse_sum(array).to_list()

and the JVP evaluated at array for the given tangent:

jvp_grad
[0.0,
 1.0,
 0.0]
-----------------
backend: jax
nbytes: 24 B
type: 3 * float64

Similarly, VJP of reverse_sum can be computed as:

value_vjp, func_vjp = jax.vjp(reverse_sum, array)

where value_vjp is the function (reverse_sum) evaluated at array (forward pass):

assert value_vjp.to_list() == reverse_sum(array).to_list()

and func_vjp is a function that takes a cotangent vector as an argument and returns the VJP (backward pass):

cotangent = ak.Array([0.0, 1.0, 0.0], backend="jax")
func_vjp(cotangent)
(<Array [[0.0, 1.0, 0.0], [], [0.0, 1.0]] type='3 * var * float64'>,)

JAX’s own documentation encourages the user to use jax.numpy instead of the canonical numpy module when operating upon JAX arrays. However, jax.numpy does not understand Awkward Arrays, so for ak.Arrays you should use the normal ak and numpy functions instead.