# Intro to Jax

## Imports

Some of the libraries we will use throughout this post are imported below.

```python
import time
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
```

## Introduction

The Jax Quickstart tutorial states:

> JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.

What does this mean? And how does this differ from other deep learning libraries such as PyTorch and TensorFlow?

As is standard, we will import some jax libraries and functions:

```python
import jax
from jax import jit, grad, vmap
from jax import random
import jax.numpy as jnp
import jax.scipy as jscp
```

## Backend: XLA

Jax is essentially a frontend that uses the XLA compiler to turn python code and vector operations into machine instructions for different computer architectures. The most common accelerator is the GPU, but there are others, such as TPUs, or other specially created hardware which accelerates operations or makes them more efficient in some way. The **point is that python is slow and XLA makes this very fast using techniques such as fusing operations and removing redundant code and operations**. Personally, this feels like a pretty future-proof way of decoupling how we specify what we want (e.g. python+jax) from how it is made to run on hardware (here, via XLA). It reminds me of how LSP solved the decoupling problem for code editing across editors^{1}. There seems to be ever more specialized hardware being created, e.g. for inference of LLMs (I saw several LLM inference hardware companies at NeurIPS 2023), so who knows what funky architectures will become available in the future.
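As a small illustration of the kind of optimization XLA performs, here is a sketch (the specific function and sizes are my own choices, not from any official benchmark) of jitting a chain of elementwise operations, which XLA can fuse into a single kernel instead of materializing each intermediate array:

```python
import timeit
import jax
import jax.numpy as jnp

def chain(x):
    # Several elementwise operations in a row; XLA can fuse these into
    # one kernel, avoiding intermediate arrays between each step.
    return jnp.sum(jnp.exp(jnp.sin(x) ** 2 + jnp.cos(x) ** 2))

jit_chain = jax.jit(chain)

x = jnp.ones((1000, 1000))
jit_chain(x).block_until_ready()  # compile once before timing

t_eager = timeit.timeit(lambda: chain(x).block_until_ready(), number=20)
t_jit = timeit.timeit(lambda: jit_chain(x).block_until_ready(), number=20)
print(f"eager: {t_eager:.4f}s, jit: {t_jit:.4f}s")
```

The exact numbers depend on your hardware, but the fused version typically comes out well ahead because the eager version launches one kernel per operation.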

## What is jax?

Jax is a reimplementation of the older linear algebra and science stack for python, including `numpy` and `scipy`, with a just-in-time compiler and ways to perform automatic differentiation. To really hammer this home, jax has reimplemented a subset of both of these packages which seems pretty feature-complete. The current state of this API can be found in the docs.
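To make the mirroring concrete, here is a small sketch comparing `jax.numpy` and `jax.scipy` calls against their classic counterparts, along with one notable difference I think is worth flagging: jax arrays are immutable, so in-place updates go through the `.at[...]` syntax.

```python
import numpy as np
import jax.numpy as jnp
import jax.scipy as jscp

# The jax.numpy API mirrors numpy very closely.
a = jnp.linspace(0.0, 1.0, 5)
print(jnp.mean(a), np.mean(np.linspace(0.0, 1.0, 5)))

# jax.scipy mirrors scipy, e.g. the stats distributions.
print(jscp.stats.norm.pdf(0.0))  # standard normal density at 0

# One difference: jax arrays are immutable, so instead of item
# assignment we use the functional .at[...] update syntax, which
# returns a new array and leaves the original unchanged.
b = a.at[0].set(42.0)
print(b[0], a[0])
```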

## Jax primitives: `jit`, `grad`, `vmap`

There are 3 functions which are integral to almost any jax program.

### `jit`

The `jit` function takes a large subset of python
together with jax functions and compiles it down to XLA kernels which
are very fast. Below I've done a very quick benchmark of how
`jit` speeds up matrix-matrix multiplication.

```python
import timeit

def jax_matmul(A, B):
    return A @ B

jit_jax_matmul = jit(jax_matmul)

n, p, k = 10**4, 10**4, 10**4
A = jnp.ones((n, p))
B = jnp.ones((p, k))

jit_jax_matmul(A, B)  # Trace and compile the jit function once

print(f"jax: {timeit.timeit(lambda: jax_matmul(A, B).block_until_ready(), number=10)}")
print(f"jax (JIT): {timeit.timeit(lambda: jit_jax_matmul(A, B).block_until_ready(), number=10)}")
```

```
jax: 0.37372643800335936
jax (JIT): 0.0003170749987475574
```

which is a sizeable speedup. The gains are even more meaningful when we jit code that does not already have an efficient implementation (unlike a matmul, which already dispatches to a fast kernel). Additionally, jit allows us to speed up things which would require considerable vectorization effort in numpy, or which may be outright impossible to vectorize.

### `grad`

The `grad` function takes as input a function \(f\)
mapping to \(\mathbb{R}\) and spits out the gradient of that function
\(\nabla f\). This can be a very natural way of working with
gradients if you are used to the math.

```python
def sum_of_squares(x):
    return jnp.sum(x**2)

sum_of_squares_dx = grad(sum_of_squares)
```

The function `sum_of_squares_dx` is the mathematical
gradient of `sum_of_squares`. Randomness in jax is handled
explicitly by splitting the state (the key); you can read about it in
the jax documentation on pseudorandom numbers.

```python
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
in_x = jax.random.normal(key, (3, 3))
dx = sum_of_squares_dx(in_x)
print(dx)
print(dx.shape)
```

```
[[-5.2211165   0.06770565  2.1726665 ]
 [-2.960598    3.0806496   2.125032  ]
 [ 1.0834967   0.0340456   0.544537  ]]
(3, 3)
```
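Since \(f(x) = \sum_i x_i^2\) has the analytic gradient \(\nabla f(x) = 2x\), we can sanity-check that `grad` agrees with the math. This is a self-contained sketch re-deriving the example above:

```python
import jax
import jax.numpy as jnp
from jax import grad

def sum_of_squares(x):
    return jnp.sum(x ** 2)

# For f(x) = sum_i x_i^2 the gradient is exactly 2x,
# so the autodiff result should match elementwise.
x = jax.random.normal(jax.random.PRNGKey(0), (3, 3))
dx = grad(sum_of_squares)(x)
print(jnp.allclose(dx, 2 * x))  # True
```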

### `vmap`

The function `vmap` allows you to lift a function to a
batched function, **without having to go through
vectorization**. For example, if we want to batch the
`sum_of_squares` function we can do so by simply applying
`vmap`:

```python
batched_sum_of_squares = vmap(sum_of_squares)
x = jax.random.normal(key, (5, 3, 3))
print(batched_sum_of_squares(x))
print(batched_sum_of_squares(x).shape)
```

```
[ 7.109205   7.1214614 21.167786   6.137778   4.915494 ]
(5,)
```

This is pretty powerful: often it's easy to specify the function
for a single sample \(x\) but harder to vectorize it. For a standard neural
network it may be pretty simple, but imagine something like LLMs,
GANs, or working with inputs which are not points, e.g. sets.
Additionally, we can use the `in_axes` argument to batch over
some input arguments and ignore others.

```python
def multi_matmul(A, B, C):
    return A @ B @ C

# Batch according to the first and third input arguments, not the second
vmap_multi_matmul = vmap(multi_matmul, in_axes=(0, None, 0))

l, n, p, d, m = 3, 5, 7, 9, 11
A = jnp.ones((l, n, p))
B = jnp.ones((p, d))
C = jnp.ones((l, d, m))
print(vmap_multi_matmul(A, B, C).shape)  # l batches of (n, m) -> (l, n, m)
```

```
(3, 5, 11)
```

### Composition

You can compose all of these functions as you see fit:

```python
jit_batched_sum_of_squares_dx = jit(vmap(grad(sum_of_squares)))
print(jit_batched_sum_of_squares_dx(x).shape)
```

```
(5, 3, 3)
```

This allows us to utilize the autodiff framework fully.
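Composition also goes beyond first derivatives. As a sketch (using `jacfwd`, a jax transform not introduced above), we can build a jitted Hessian by composing a Jacobian transform with `grad`:

```python
import jax.numpy as jnp
from jax import grad, jit, jacfwd

def sum_of_squares(x):
    return jnp.sum(x ** 2)

# The Jacobian of the gradient is the Hessian; jit compiles the whole
# composition down to a single XLA computation.
hessian = jit(jacfwd(grad(sum_of_squares)))

x = jnp.arange(3.0)
H = hessian(x)
print(H)  # for sum of squares the Hessian is 2 * I
```

Because every transform returns an ordinary function, higher-order derivatives are just more nesting.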

## Building a neural network from scratch

We'll build an MLP using nothing but jax and train it on MNIST. To load the data I'm using the jax-dataloader library.

```python
import jax_dataloader as jdl
from torchvision.datasets import MNIST

pt_ds = MNIST("/tmp/mnist", download=True, transform=lambda x: np.array(x, np.float32), train=True)
train_dataloader = jdl.DataLoader(pt_ds, backend="pytorch", batch_size=128, shuffle=True)
pt_ds = MNIST("/tmp/mnist", download=True, transform=lambda x: np.array(x, np.float32), train=False)
test_dataloader = jdl.DataLoader(pt_ds, backend="pytorch", batch_size=128, shuffle=True)
```

The jax library has some helpful functions for building neural networks. Here we create the parameters and define a prediction function which, given a pytree of parameters and an input, outputs the predicted logits. Pytrees are a great feature of jax: they allow us to intuitively and efficiently work not only with raw arrays but also with tree-like structures built by composing lists, tuples, and dictionaries, with arrays as leaves, and to map over these as if they were arrays.
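To make the pytree idea concrete, here is a small sketch (the nested structure and values are made up for illustration) of applying a function to every array leaf of a nested structure with `tree_map`:

```python
import jax
import jax.numpy as jnp

# A pytree: nested lists/dicts with arrays at the leaves,
# mimicking the parameter structure we build below.
params = [
    {"W": jnp.ones((2, 3)), "b": jnp.zeros(3)},
    {"W": jnp.ones((3, 1)), "b": jnp.zeros(1)},
]

# tree_map applies a function to every leaf, preserving the structure.
scaled = jax.tree_util.tree_map(lambda p: 0.5 * p, params)
print(scaled[0]["W"])

# Structural queries also work on the whole tree at once.
print(jax.tree_util.tree_map(lambda p: p.shape, params))
```

This is exactly the mechanism the parameter update below relies on: gradients come back with the same tree structure as the parameters, so a single `tree_map` implements the update for every layer at once.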

### Creating the neural network

```python
from jax.nn import relu
from jax.nn.initializers import glorot_normal
from jax.scipy.special import logsumexp

def create_mlp_weights(num_layers: int, in_dim: int, out_dim: int, hidden_dim: int, key):
    # Helper function for generating the weights and biases of a single layer
    def create_layer_weights(in_dim, out_dim, key):
        return {
            "W": glorot_normal()(key, (in_dim, out_dim)),
            "b": np.zeros(out_dim),
        }

    params = []
    key, subkey = jax.random.split(key)
    # Fill out the parameter list with dictionaries of layer weights and biases
    params.append(create_layer_weights(in_dim, hidden_dim, subkey))
    for _ in range(1, num_layers):
        key, subkey = jax.random.split(key)
        params.append(create_layer_weights(hidden_dim, hidden_dim, key))
    key, subkey = jax.random.split(key)
    params.append(create_layer_weights(hidden_dim, out_dim, subkey))
    return params

def predict(params, x):
    for layer in params[:-1]:
        x = relu(x @ layer["W"] + layer["b"])
    logits = x @ params[-1]["W"] + params[-1]["b"]
    return logits - logsumexp(logits)
```

Let's pick some reasonable defaults. We see that all shapes are
correct and we have batched the `predict` function.

```python
num_layers = 3
in_dim = 28 * 28
out_dim = 10
hidden_dim = 128
key = jax.random.PRNGKey(2023)

params = create_mlp_weights(num_layers, in_dim, out_dim, hidden_dim, key)
print(predict(params, jnp.ones(28 * 28)))

batched_predict = vmap(predict, in_axes=(None, 0))
print(batched_predict(params, jnp.ones((4, 28 * 28))).shape)
print(len(params))
print(type(params[0]["W"]))
```

```
[-3.3419425 -1.4851335 -2.5466485 -3.1445212 -1.8924606 -2.5047162
 -2.622343  -2.6072748 -1.5674857 -3.5270252]
(4, 10)
4
<class 'jaxlib.xla_extension.ArrayImpl'>
```

### Training

Now we write the helper functions to train this network. In particular, we use jax's pytree functionality to update the parameters, which form a pytree since they are a list of dictionaries of arrays.

```python
import jax.tree_util as tree_util

def one_hot(x, k, dtype=jnp.float32):
    """Create a one-hot encoding of x of size k."""
    return jnp.array(x[:, None] == jnp.arange(k), dtype)

@jit
def accuracy(params, images, targets):
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
    return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
    preds = batched_predict(params, images)
    return -jnp.mean(preds * targets)

@jit
def update(params, x, y, step_size):
    grads = grad(loss)(params, x, y)
    return tree_util.tree_map(lambda w, g: w - step_size * g, params, grads)

EPOCHS = 10
STEP_SIZE = 10 ** -2

train_acc = []
train_loss = []
test_acc = []
test_loss = []
for epoch in range(EPOCHS):
    print('Epoch', epoch)
    for image, output in train_dataloader:
        image, output = jnp.array(image).reshape(-1, 28 * 28), one_hot(jnp.array(output), 10)
        train_acc.append(accuracy(params, image, output).item())
        train_loss.append(loss(params, image, output).item())
        params = update(params, image, output, STEP_SIZE)
    print(f'Train accuracy: {np.mean(train_acc):.3f}')
    print(f'Train loss: {np.mean(train_loss):.3f}')

    _test_acc = []
    _test_loss = []
    for image, output in test_dataloader:
        image, output = jnp.array(image).reshape(-1, 28 * 28), one_hot(jnp.array(output), 10)
        _test_acc.append(accuracy(params, image, output).item())
        _test_loss.append(loss(params, image, output).item())
    test_acc.append(_test_acc)
    test_loss.append(_test_loss)
    print(f'Test accuracy: {np.mean(test_acc):.3f}')
    print(f'Test loss: {np.mean(test_loss):.3f}')
```

```
Epoch 0
Train accuracy: 0.788
Train loss: 0.213
Test accuracy: 0.856
Test loss: 0.073
Epoch 1
Train accuracy: 0.832
Train loss: 0.135
Test accuracy: 0.872
Test loss: 0.062
Epoch 2
Train accuracy: 0.856
Train loss: 0.103
Test accuracy: 0.882
Test loss: 0.055
Epoch 3
Train accuracy: 0.872
Train loss: 0.085
Test accuracy: 0.889
Test loss: 0.051
Epoch 4
Train accuracy: 0.883
Train loss: 0.074
Test accuracy: 0.894
Test loss: 0.048
Epoch 5
Train accuracy: 0.892
Train loss: 0.065
Test accuracy: 0.898
Test loss: 0.045
Epoch 6
Train accuracy: 0.899
Train loss: 0.059
Test accuracy: 0.902
Test loss: 0.043
Epoch 7
Train accuracy: 0.905
Train loss: 0.054
Test accuracy: 0.904
Test loss: 0.042
Epoch 8
Train accuracy: 0.910
Train loss: 0.050
Test accuracy: 0.907
Test loss: 0.040
Epoch 9
Train accuracy: 0.914
Train loss: 0.046
Test accuracy: 0.909
Test loss: 0.039
```

Finally, we plot the learning curves.

```python
sns.set_theme("notebook")
sns.set_style("ticks")

iterations_per_epoch = len(train_dataloader)

fig, ax = plt.subplots(2, 1)
ax[0].plot(np.array(train_loss), label="train_loss")
ax[0].plot((np.arange(len(test_loss)) + 1) * iterations_per_epoch, np.array(test_loss).mean(-1), label="test_loss")
ax[0].set_ylim([0.0, 0.1])
ax[0].legend()
ax[1].plot(np.array(train_acc), label="train_acc")
ax[1].plot((np.arange(len(test_acc)) + 1) * iterations_per_epoch, np.array(test_acc).mean(-1), label="test_acc")
ax[1].set_ylim([0.8, 1.0])
ax[1].legend()
plt.tight_layout()
```

^{1} LSP decoupled the implementation of code editing features by moving them into a language server which editors then use through a frontend. This way the frontend relies on a consistent API, and the server logic does not have to be reimplemented for every editor.