# Differentiation using JAX

JAX, amongst other things, is a powerful tool for computing derivatives of native Python and NumPy code. Awkward Array implements support for the `jax.jvp()` and `jax.vjp()` JAX functions, which compute forward-mode Jacobian-vector products and reverse-mode vector-Jacobian products, respectively, of functions that operate upon Awkward Arrays. Only a subset of Awkward Array operations can be differentiated through, including:

- ufunc operations like `x + y`
- reducers like `ak.sum()`
- slices like `x[1:]`

## How to differentiate Awkward Arrays?

Before using JAX on functions that operate on Awkward Arrays, we need to configure JAX to use only the CPU:

```
import jax
jax.config.update("jax_platform_name", "cpu")
```

Next, we must call `ak.jax.register_and_check()` to register Awkward’s JAX integration:

```
import awkward as ak
ak.jax.register_and_check()
```

Let’s define a simple function that accepts an Awkward Array:

```
def reverse_sum(array):
    # Reverse the order of the sublists, then sum across them position by position
    return ak.sum(array[::-1], axis=0)
```

We can then create an array with which to evaluate `reverse_sum`. The `backend` argument ensures that we build an Awkward Array that is backed by `jaxlib.xla_extension.DeviceArray` buffers, which power JAX’s automatic differentiation and JIT compilation features.

```
array = ak.Array([[1.0, 2.0, 3.0], [], [4.0, 5.0]], backend="jax")
```

```
reverse_sum(array)
```

Output: `[5.0, 7.0, 3.0]` (type: `3 * float32`)

Computing the JVP of `reverse_sum` requires a *tangent* vector, which can also be defined as an Awkward Array:

```
tangent = ak.Array([[0.0, 0.0, 0.0], [], [0.0, 1.0]], backend="jax")
```

```
value_jvp, jvp_grad = jax.jvp(reverse_sum, (array,), (tangent,))
```

`jax.jvp()` returns both the value of `reverse_sum` evaluated at `array`:

```
value_jvp
```

Output: `[5.0, 7.0, 3.0]` (type: `3 * float32`)

```
assert value_jvp.to_list() == reverse_sum(array).to_list()
```

and the JVP evaluated at `array` for the given `tangent`:

```
jvp_grad
```

Output: `[0.0, 1.0, 0.0]` (type: `3 * float32`)
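As a cross-check of these numbers, the sketch below writes out the Jacobian of `reverse_sum` by hand using plain NumPy, without JAX or Awkward Array. It relies on two observations: reversing the sublists does not change an `axis=0` sum, and the function is linear, so its Jacobian is a constant matrix. The flattened representation and the names `values`, `columns`, `jacobian`, and `cotangent` are assumptions made for illustration; this is not how Awkward Array computes derivatives internally.

```python
import numpy as np

# Flatten the ragged input [[1, 2, 3], [], [4, 5]] into its five values and
# record the position of each value within its sublist.
values = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
columns = np.array([0, 1, 2, 0, 1])

# Output element j is the sum of every value at position j, so the Jacobian
# is a 3 x 5 matrix with a 1 wherever a value contributes to an output.
jacobian = (np.arange(3)[:, None] == columns[None, :]).astype(float)

# Forward mode (JVP): push an input-shaped tangent through the Jacobian.
tangent = np.array([0.0, 0.0, 0.0, 0.0, 1.0])  # [[0, 0, 0], [], [0, 1]] flattened
jvp = jacobian @ tangent

# Reverse mode (VJP): pull an output-shaped cotangent back through the Jacobian.
cotangent = np.array([0.0, 1.0, 0.0])
vjp = cotangent @ jacobian

print(jacobian @ values)  # value of the function: [5. 7. 3.]
print(jvp)                # [0. 1. 0.]
print(vjp)                # [0. 1. 0. 0. 1.]
```

The JVP matches the result above. The VJP has one entry per input value, which corresponds to the ragged layout `[[0, 1, 0], [], [0, 1]]`.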

JAX’s own documentation encourages the user to use `jax.numpy` instead of the canonical `numpy` module when operating upon JAX arrays. However, `jax.numpy` does not understand Awkward Arrays, so for `ak.Array`s you should use the normal `ak` and `numpy` functions instead.