# Differentiation using JAX

JAX is, among other things, a powerful tool for computing derivatives of native Python and NumPy code. Awkward Array supports the `jax.jvp()` and `jax.vjp()` JAX functions, which compute forward-mode Jacobian-vector products and reverse-mode vector-Jacobian products, for functions that operate upon Awkward Arrays. Only a subset of Awkward Array operations can be differentiated through, including:

• ufunc operations like `x + y`

• reducers like `ak.sum()`

• slices like `x[1:]`
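
To make the two products concrete, here is a minimal sketch on a plain JAX array rather than an Awkward Array. The function `neighbour_products` and the inputs are made up for illustration; it combines a slice with a ufunc, and the explicit Jacobian from `jax.jacfwd()` is only there to show what `jax.jvp()` and `jax.vjp()` compute without ever building that matrix.

```python
import jax
import jax.numpy as jnp


def neighbour_products(x):
    # Slice + ufunc: multiply each element by its right-hand neighbour
    return x[1:] * x[:-1]


x = jnp.array([1.0, 2.0, 3.0])

# Explicit Jacobian J of shape (2, 3), for comparison only
jacobian = jax.jacfwd(neighbour_products)(x)

# Forward mode: push an input-shaped tangent v through the function (J @ v)
v = jnp.array([0.0, 1.0, 0.0])
value, jvp_out = jax.jvp(neighbour_products, (x,), (v,))

# Reverse mode: pull an output-shaped cotangent u back to the input (u @ J)
u = jnp.array([1.0, 0.0])
_, pullback = jax.vjp(neighbour_products, x)
(vjp_out,) = pullback(u)
```

The tangent has the shape of the input and the cotangent has the shape of the output; the same convention applies to the Awkward Array examples in this guide.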

## How to differentiate Awkward Arrays?

Before using JAX on functions that deal with Awkward Arrays, we need to configure JAX to use only the CPU:

```python
import jax

jax.config.update("jax_platform_name", "cpu")
```

Next, we must call `ak.jax.register_and_check()` to register Awkward’s JAX integration:

```python
import awkward as ak

ak.jax.register_and_check()
```

Let’s define a simple function that accepts an Awkward Array:

```python
def reverse_sum(array):
    return ak.sum(array[::-1], axis=0)
```

We can then create an array with which to evaluate `reverse_sum`. The `backend` argument ensures that we build an Awkward Array that is backed by `jaxlib.xla_extension.DeviceArray` buffers, which power JAX’s automatic differentiation and JIT compiling features.

```python
array = ak.Array([[1.0, 2.0, 3.0], [], [4.0, 5.0]], backend="jax")
```

```python
reverse_sum(array)
```

```
[5.0,
7.0,
3.0]
-----------------
type: 3 * float32
```
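
To see where these three values come from, the following pure-Python sketch mimics what `reverse_sum` does to the nested lists. The helper `ragged_column_sum` is hypothetical and not part of Awkward Array; it only illustrates that summing over `axis=0` adds up elements position by position across the outer lists, with shorter lists contributing nothing to the positions they lack.

```python
from itertools import zip_longest


def ragged_column_sum(rows):
    # Sum position by position across lists of unequal length;
    # a list that is too short contributes 0.0 at that position.
    return [sum(column) for column in zip_longest(*rows, fillvalue=0.0)]


rows = [[1.0, 2.0, 3.0], [], [4.0, 5.0]]

# Reverse the outer axis, then reduce over it
result = ragged_column_sum(rows[::-1])
print(result)
```

Reversing the outer axis changes the order in which the lists are added, not the totals, so the slice matters here mainly as an operation for JAX to differentiate through.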

Computing the JVP of `reverse_sum` requires a tangent vector, which can also be defined as an Awkward Array:

```python
tangent = ak.Array([[0.0, 0.0, 0.0], [], [0.0, 1.0]], backend="jax")
```

```python
value_jvp, jvp_grad = jax.jvp(reverse_sum, (array,), (tangent,))
```

`jax.jvp()` returns both the value of `reverse_sum` evaluated at `array`:

```python
value_jvp
```

```
[5.0,
7.0,
3.0]
-----------------
type: 3 * float32
```

```python
assert value_jvp.to_list() == reverse_sum(array).to_list()
```

and the JVP evaluated at `array` for the given `tangent`:

```python
jvp_grad
```

```
[0.0,
1.0,
0.0]
-----------------
type: 3 * float32
```

JAX’s own documentation encourages the user to use `jax.numpy` instead of the canonical `numpy` module when operating upon JAX arrays. However, `jax.numpy` does not understand Awkward Arrays, so for `ak.Array`s you should use the normal `ak` and `numpy` functions instead.