Differentiation using JAX
JAX, amongst other things, is a powerful tool for computing derivatives of native Python and NumPy code. Awkward Array implements support for the jax.jvp() and jax.vjp() JAX functions, which compute forward-mode Jacobian-vector products and reverse-mode vector-Jacobian products, respectively, of functions that operate upon Awkward Arrays. Only a subset of Awkward Array operations can be differentiated through, including:
- ufunc operations like x + y
- reducers like ak.sum()
- slices like x[1:]
How to differentiate Awkward Arrays?
Before using JAX on functions that deal with Awkward Arrays, we need to configure JAX to use only the CPU:
import jax
jax.config.update("jax_platform_name", "cpu")
Next, we must call ak.jax.register_and_check() to register Awkward’s JAX integration:
import awkward as ak
ak.jax.register_and_check()
Let’s define a simple function that accepts an Awkward Array:
def reverse_sum(array):
    # reverse the outer dimension, then add the sublists element by element
    return ak.sum(array[::-1], axis=0)
We can then create an array with which to evaluate reverse_sum. The backend argument ensures that we build an Awkward Array that is backed by jaxlib.xla_extension.DeviceArray buffers (jax.Array in newer JAX releases), which power JAX’s automatic differentiation and JIT compiling features.
array = ak.Array([[1.0, 2.0, 3.0], [], [4.0, 5.0]], backend="jax")
reverse_sum(array)
[5.0, 7.0, 3.0]
-----------------
type: 3 * float32
Computing the JVP of reverse_sum requires a tangent vector, which can also be defined as an Awkward Array:
tangent = ak.Array([[0.0, 0.0, 0.0], [], [0.0, 1.0]], backend="jax")
value_jvp, jvp_grad = jax.jvp(reverse_sum, (array,), (tangent,))
jax.jvp() returns both the value of reverse_sum evaluated at array:
value_jvp
[5.0, 7.0, 3.0]
-----------------
type: 3 * float32
assert value_jvp.to_list() == reverse_sum(array).to_list()
and the JVP evaluated at array for the given tangent:
jvp_grad
[0.0, 1.0, 0.0]
-----------------
type: 3 * float32
JAX’s own documentation encourages the user to use jax.numpy instead of the canonical numpy module when operating upon JAX arrays. However, jax.numpy does not understand Awkward Arrays, so for ak.Arrays you should use the normal ak and numpy functions instead.