Differentiation using JAX

JAX, amongst other things, is a powerful tool for computing derivatives of native Python and NumPy code. Awkward Array supports the jax.jvp() and jax.vjp() functions, which compute forward-mode Jacobian-vector products and reverse-mode vector-Jacobian products, respectively, of functions that operate on Awkward Arrays. Only a subset of Awkward Array operations can be differentiated through, including:

  • ufunc operations like x + y

  • reducers like ak.sum()

  • slices like x[1:]

How to differentiate Awkward Arrays?

Before using JAX on functions that deal with Awkward Arrays, we need to configure JAX to use only the CPU:

import jax

jax.config.update("jax_platform_name", "cpu")
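If you want to confirm that the setting took effect, a quick check is to ask JAX for its default backend and devices. This is a sketch using standard JAX functions (jax.default_backend() and jax.devices()); it is not required by Awkward.

```python
import jax

# Restrict JAX to the CPU platform before any arrays are created
jax.config.update("jax_platform_name", "cpu")

# The default backend and its devices should now report the CPU platform
backend_name = jax.default_backend()
print(backend_name)                              # cpu
print([device.platform for device in jax.devices()])
```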

Next, we must call ak.jax.register_and_check() to register Awkward’s JAX integration:

import awkward as ak

ak.jax.register_and_check()

Let’s define a simple function that accepts an Awkward Array, reverses its outer dimension, and sums over axis=0:

def reverse_sum(array):
    return ak.sum(array[::-1], axis=0)
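To make clear what reverse_sum computes, here is a plain-Python reference version that works on nested lists. The name reverse_sum_reference is hypothetical, and it is only meant as a readable model of the ragged axis=0 sum, not as a replacement for the Awkward version (it cannot be differentiated by JAX).

```python
from itertools import zip_longest


def reverse_sum_reference(rows):
    """Plain-Python model of reverse_sum for a list of lists of floats."""
    reversed_rows = rows[::-1]  # same as array[::-1]: reverse the outer list
    # ak.sum(..., axis=0) adds elements that share an inner index; shorter
    # lists contribute nothing at positions they do not have
    columns = zip_longest(*reversed_rows, fillvalue=0.0)
    return [sum(column) for column in columns]


print(reverse_sum_reference([[1.0, 2.0, 3.0], [], [4.0, 5.0]]))  # [5.0, 7.0, 3.0]
```

Because addition is commutative, reversing the outer list does not change the sums; the reversal is there to exercise slicing when differentiating.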

We can then create an array with which to evaluate reverse_sum. The backend="jax" argument ensures that we build an Awkward Array backed by JAX array buffers (jax.Array in current JAX releases; older releases used jaxlib.xla_extension.DeviceArray), which power JAX’s automatic differentiation and JIT compilation features.

array = ak.Array([[1.0, 2.0, 3.0], [], [4.0, 5.0]], backend="jax")
reverse_sum(array)
[5.0,
 7.0,
 3.0]
-----------------
type: 3 * float32

Computing the JVP of reverse_sum requires a tangent vector, which can also be defined as an Awkward Array with the same structure as array:

tangent = ak.Array([[0.0, 0.0, 0.0], [], [0.0, 1.0]], backend="jax")
value_jvp, jvp_grad = jax.jvp(reverse_sum, (array,), (tangent,))

jax.jvp() returns both the value of reverse_sum evaluated at array:

value_jvp
[5.0,
 7.0,
 3.0]
-----------------
type: 3 * float32
assert value_jvp.to_list() == reverse_sum(array).to_list()

and the JVP evaluated at array for the given tangent:

jvp_grad
[0.0,
 1.0,
 0.0]
-----------------
type: 3 * float32

JAX’s own documentation encourages the user to use jax.numpy instead of the canonical numpy module when operating upon JAX arrays. However, jax.numpy does not understand Awkward Arrays, so for ak.Arrays you should use the normal ak and numpy functions instead.