# Differentiation using JAX
JAX, amongst other things, is a powerful tool for computing derivatives of native Python and NumPy code. Awkward Array implements support for the `jax.grad()`, `jax.jvp()`, and `jax.vjp()` JAX functions, which compute gradients and forward/reverse-mode Jacobian-vector/vector-Jacobian products of functions that operate upon Awkward Arrays. Only a subset of Awkward Array operations can be differentiated through, including:

- ufunc operations like `x + y`
- reducers like `ak.sum()`
- slices like `x[1:]`
## How to differentiate Awkward Arrays?
For this notebook (which is evaluated on a CPU), we need to configure JAX to use only the CPU:

```python
import jax

jax.config.update("jax_platform_name", "cpu")
```
Next, we must call `ak.jax.register_and_check()` to register Awkward’s JAX integration:

```python
import awkward as ak

ak.jax.register_and_check()
```
Let’s define a simple function that accepts an Awkward Array:

```python
def reverse_sum(array):
    return ak.sum(array[::-1], axis=0)
```
We can then create an array with which to evaluate `reverse_sum`. The `backend` argument ensures that we build an Awkward Array that is backed by `jax.Array` (`jaxlib.xla_extension.ArrayImpl`) buffers, which power JAX’s automatic differentiation and JIT compilation features. However, Awkward Array’s JAX backend does not support JIT compilation of reducers, because XLA requires that array sizes not depend upon data values at compile time.

```python
array = ak.Array([[1.0, 2.0, 3.0], [], [4.0, 5.0]], backend="jax")
reverse_sum(array)
```

```
[5.0, 7.0, 3.0]
-----------------
backend: jax
nbytes: 24 B
type: 3 * float64
```
Computing the JVP of `reverse_sum` requires a tangent vector, which can also be defined as an Awkward Array:

```python
tangent = ak.Array([[0.0, 0.0, 0.0], [], [0.0, 1.0]], backend="jax")

value_jvp, jvp_grad = jax.jvp(reverse_sum, (array,), (tangent,))
```
`jax.jvp()` returns both the value of `reverse_sum` evaluated at `array`:

```python
value_jvp
```

```
[5.0, 7.0, 3.0]
-----------------
backend: jax
nbytes: 24 B
type: 3 * float64
```

```python
assert value_jvp.to_list() == reverse_sum(array).to_list()
```

and the JVP evaluated at `array` for the given `tangent`:

```python
jvp_grad
```

```
[0.0, 1.0, 0.0]
-----------------
backend: jax
nbytes: 24 B
type: 3 * float64
```
Similarly, the VJP of `reverse_sum` can be computed as:

```python
value_vjp, func_vjp = jax.vjp(reverse_sum, array)
```

where `value_vjp` is the function (`reverse_sum`) evaluated at `array` (the forward pass):

```python
assert value_vjp.to_list() == reverse_sum(array).to_list()
```

and `func_vjp` is a function that takes a cotangent vector as an argument and returns the VJP (the backward pass):

```python
cotangent = ak.Array([0.0, 1.0, 0.0], backend="jax")

func_vjp(cotangent)
```

```
(<Array [[0.0, 1.0, 0.0], [], [0.0, 1.0]] type='3 * var * float64'>,)
```
JAX’s own documentation encourages the user to use `jax.numpy` instead of the canonical `numpy` module when operating upon JAX arrays. However, `jax.numpy` does not understand Awkward Arrays, so for `ak.Array`s you should use the normal `ak` and `numpy` functions instead.