Welcome to JAX 101! If you are just starting with differentiable programming, think of this guide as a friendly walking tour. We will slow down at each turn, explain what the tooling does, and show the exact notebook outputs so you can compare your run with ours.

The goal is simple: open a fresh Google Colab notebook, install the JAX stack, and train a logistic regression model using only JAX and Optax. Along the way we keep the NumPy feel while gently introducing JAX concepts like jit, grad, and vmap.

Everything below is copy-paste ready. After you run each cell, check the output blocks—we captured them from a clean Colab session so you know what “success” looks like before moving on.

JAX in equations

For a binary classification dataset $(X, y)$ with $X \in \mathbb{R}^{N \times K}$ and $y \in \{0, 1\}^N$, a logistic regression loss under JAX-friendly notation is

$$ \mathcal{L}(w, b) = -\frac{1}{N} \sum_{n=1}^N \left[ y_n \log \sigma(\eta_n) + (1 - y_n) \log (1 - \sigma(\eta_n)) \right], $$

$$ \eta_n = x_n^\top w + b, $$

and gradient descent updates follow

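$$ \theta^{(t+1)} = \theta^{(t)} - \alpha \nabla_\theta \mathcal{L}(\theta^{(t)}), $$

where $\theta = (w, b)$, $\alpha > 0$ is the learning rate, and $\sigma(\cdot)$ is the logistic sigmoid. jax.grad computes $\nabla_\theta \mathcal{L}$ automatically, and jax.jit compiles the update with XLA so Colab’s CPU, GPU, or TPU backend can execute it efficiently. The Optax loop in section 4 replaces this plain gradient-descent step with Adam, which rescales the step for each parameter.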
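To connect the equations to code, here is a minimal sketch that writes $\mathcal{L}(w, b)$ exactly as above, checks jax.grad against the hand-derived gradient $\nabla_w \mathcal{L} = \frac{1}{N} X^\top (\sigma(\eta) - y)$ and $\partial \mathcal{L} / \partial b = \frac{1}{N} \sum_n (\sigma(\eta_n) - y_n)$, and then takes plain gradient-descent steps. The helper names, the four-point toy dataset, and the step size of 0.5 (called alpha in the code to keep it distinct from the logits $\eta_n$) are illustrative assumptions, not part of the tutorial's pipeline.

```python
import jax
import jax.numpy as jnp

def logistic_loss(theta, X, y):
    # theta = (w, b); eta_n = x_n^T w + b
    w, b = theta
    eta = X @ w + b
    # log sigma(eta) and log(1 - sigma(eta)) = log sigma(-eta), computed stably
    return -jnp.mean(y * jax.nn.log_sigmoid(eta) + (1.0 - y) * jax.nn.log_sigmoid(-eta))

def analytic_grad(theta, X, y):
    # Hand-derived gradient: dL/dw = X^T (sigma(eta) - y) / N, dL/db = mean(sigma(eta) - y)
    w, b = theta
    residual = jax.nn.sigmoid(X @ w + b) - y
    return X.T @ residual / X.shape[0], jnp.mean(residual)

@jax.jit
def gd_step(theta, X, y, alpha):
    # theta^{(t+1)} = theta^{(t)} - alpha * grad L(theta^{(t)})
    grads = jax.grad(logistic_loss)(theta, X, y)
    return jax.tree_util.tree_map(lambda p, g: p - alpha * g, theta, grads)

# Hypothetical toy data, only for the check
X = jnp.array([[0.5, -1.0], [1.5, 0.3], [-0.7, 0.8], [0.2, 0.1]])
y = jnp.array([1.0, 1.0, 0.0, 0.0])
theta = (jnp.array([0.1, -0.2]), jnp.array(0.05))

auto_w, auto_b = jax.grad(logistic_loss)(theta, X, y)
hand_w, hand_b = analytic_grad(theta, X, y)
print(jnp.allclose(auto_w, hand_w, atol=1e-6), jnp.allclose(auto_b, hand_b, atol=1e-6))

for _ in range(100):
    theta = gd_step(theta, X, y, 0.5)
print("loss after 100 steps:", logistic_loss(theta, X, y))
```

The agreement between autodiff and the closed form is a cheap sanity check worth keeping around whenever you hand-write a loss.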

Why JAX for differentiable computing

  • Keep NumPy syntax while gaining auto-diff, just-in-time (JIT) compilation, and accelerator support.
  • Compose transformations (grad, jit, vmap, pmap) instead of rewriting kernels when requirements change.
  • Reproducible random number workflow via explicit PRNGKey handling.
  • Integrate with the broader ecosystem: Optax optimizers, Flax modules, Equinox models, or Haiku networks without switching paradigms.
  • Prototype in notebooks, then copy-paste into production code—no hidden notebook magic after the %pip installs.
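The composition and PRNG bullets are easy to see in a few lines. The sketch below (the function f and the key value 42 are arbitrary choices) stacks grad, vmap, and jit to get a compiled, batched derivative, checks it against the product rule, and shows that reusing a key reproduces a draw exactly while a split key gives a different one.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Scalar-in, scalar-out example function: f(x) = x * sin(x)
    return x * jnp.sin(x)

# grad differentiates f, vmap maps that derivative over a batch,
# and jit compiles the composed function with XLA.
df_batched = jax.jit(jax.vmap(jax.grad(f)))

xs = jnp.linspace(-3.0, 3.0, 7)
auto = df_batched(xs)
by_hand = jnp.sin(xs) + xs * jnp.cos(xs)  # product rule
print(jnp.allclose(auto, by_hand, atol=1e-5))

# Explicit PRNG keys: the same key always reproduces the same numbers,
# and split keys give independent streams.
key = jax.random.PRNGKey(42)
k1, k2 = jax.random.split(key)
draw_a = jax.random.normal(k1, (3,))
draw_b = jax.random.normal(k1, (3,))
draw_c = jax.random.normal(k2, (3,))
print(jnp.array_equal(draw_a, draw_b), jnp.array_equal(draw_a, draw_c))
```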

1. Launch a Colab runtime

Open a fresh Google Colab notebook, pick the runtime type you need (Runtime → Change runtime type → a GPU such as T4 for CUDA, or leave it on CPU), then install the packages we rely on. Restart the runtime after the install so Python imports the freshly installed jax and jaxlib instead of the versions already loaded into the session.

%pip install -q -U "jax[cpu]" optax flax einops matplotlib

If you enabled a GPU runtime, a quick nvidia-smi check shows which accelerator Colab assigned (our session received a Tesla T4):

!nvidia-smi
Tue Oct 28 21:15:40 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   38C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
| Processes:                                                                              |
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

Verify the backend and available devices:

import jax
import jax.numpy as jnp

print("JAX version:", jax.__version__)
print("Backend:", jax.default_backend())
print("Devices:", jax.devices())
JAX version: 0.7.2
Backend: gpu
Devices: [CudaDevice(id=0)]

If you switched the runtime to GPU, install the CUDA build instead of "jax[cpu]", for example %pip install -q -U "jax[cuda12]" optax flax einops matplotlib. Recent JAX releases publish the CUDA 12 wheels on PyPI, so no separate wheel URL is needed.

2. Explore JAX arrays and transformations

Colab ships with regular NumPy, but we’ll operate on jax.numpy arrays, which grad and jit can trace and differentiate. The snippet below shows broadcasting, grad, and jit in one go.

from jax import grad, jit

x = jnp.linspace(-2.0, 2.0, 5)
tau = 0.5

def tempered_softmax(x, temperature):
    shifted = x - jnp.max(x)
    weights = jnp.exp(shifted / temperature)
    return weights / jnp.sum(weights)

softmax_val = tempered_softmax(x, tau)
dsoftmax_dtau = grad(lambda t: tempered_softmax(x, t).sum())(tau)
print("Softmax:", softmax_val)
print("d/dtau sum =", dsoftmax_dtau)

compiled_softmax = jit(tempered_softmax)
_ = compiled_softmax(x, tau)  # first call compiles
Softmax: [2.9007587e-04 2.1433870e-03 1.5837606e-02 1.1702495e-01 8.6470395e-01]
d/dtau sum = -4.3092886e-08

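tempered_softmax never mutates arrays in place, which keeps JAX’s functional semantics intact. The grad transformation differentiates through the entire calculation. jit traces the function on its first call for a given set of input shapes and dtypes, compiles it with XLA, and reuses the cached executable on later calls with matching shapes; a new shape or dtype triggers a fresh trace and compilation.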
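You can watch this caching happen. In the sketch below, the Python counter is a side effect that only executes while jit traces the function, so it counts traces rather than calls; scaled_sum and trace_count are illustrative names. Two calls with a length-5 input share one trace, and a length-8 input forces a second.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def scaled_sum(x, temperature):
    global trace_count
    # This Python statement runs only during tracing, not on cached calls.
    trace_count += 1
    return jnp.sum(x / temperature)

scaled_sum(jnp.ones(5), 0.5)   # first call: trace + compile
scaled_sum(jnp.ones(5), 2.0)   # same shapes/dtypes: cached executable reused
scaled_sum(jnp.ones(8), 0.5)   # new shape: traced and compiled again
print("traces:", trace_count)  # → traces: 2
```

The same side-effect trick is also why a plain print inside a jitted function appears only once: it fires at trace time.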

grad requires a scalar-valued function, so we append .sum() to collapse the probability vector into a single number (its total probability mass) before taking $\frac{\partial}{\partial \tau}$. If you expand the expression,

$$ \frac{\partial}{\partial \tau} \sum_k \frac{\exp((x_k - \max(x)) / \tau)}{\sum_j \exp((x_j - \max(x)) / \tau)} $$

summing the numerators over $k$ reproduces the denominator exactly, so the sum equals $1$ for every $\tau$ and its derivative is $0$. The tiny value printed above is floating-point round-off, which confirms that JAX differentiated through the computation and recovered the expected invariant. To probe how temperature actually reshapes the distribution, replace .sum() with a scalar functional that is not constant in $\tau$, such as the entropy.
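As an example of a functional that does depend on $\tau$, the sketch below differentiates the entropy $H(\tau) = -\sum_k p_k \log p_k$ of the tempered softmax. For $p_k \propto \exp(x_k / \tau)$ one can show $dH/d\tau = \operatorname{Var}_p(x) / \tau^3$, which is positive whenever the $x_k$ are not all equal, so raising the temperature flattens the distribution. The code compares jax.grad with that closed form; softmax_entropy is an illustrative helper, not a JAX API.

```python
import jax
import jax.numpy as jnp

def tempered_softmax(x, temperature):
    shifted = x - jnp.max(x)
    weights = jnp.exp(shifted / temperature)
    return weights / jnp.sum(weights)

def softmax_entropy(x, temperature):
    # H(tau) = -sum_k p_k log p_k for p = tempered_softmax(x, tau)
    p = tempered_softmax(x, temperature)
    return -jnp.sum(p * jnp.log(p))

x = jnp.linspace(-2.0, 2.0, 5)
tau = 0.5
dH_dtau = jax.grad(softmax_entropy, argnums=1)(x, tau)

# Closed form for comparison: dH/dtau = Var_p(x) / tau**3
p = tempered_softmax(x, tau)
var_x = jnp.sum(p * x**2) - jnp.sum(p * x) ** 2
print("autodiff:", dH_dtau, "closed form:", var_x / tau**3)
```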

3. Simulate a logistic regression dataset

We’ll synthesize a dataset so you can rerun the same notebook from scratch. Deriving every random draw from an explicit jax.random key (split once for the features and once for the label noise) makes the data identical on every rerun.

def make_dataset(key, n_samples=512, scale=1.0):
    key_x, key_noise = jax.random.split(key)
    X = scale * jax.random.normal(key_x, (n_samples, 2))
    true_w = jnp.array([2.0, -1.0])
    true_b = -0.8
    logits = X @ true_w + true_b
    probs = jax.nn.sigmoid(logits)
    y = jax.random.bernoulli(key_noise, probs).astype(jnp.float32)
    return X, y

key = jax.random.PRNGKey(0)
X_train, y_train = make_dataset(key)
print("Shapes:", X_train.shape, y_train.shape)
Shapes: (512, 2) (512,)

Because the dataset is small, we can train on the full batch. For larger corpora, jax.numpy slicing keeps mini-batches easy to manage.
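One way to do that, as a sketch: shuffle row indices with jax.random.permutation under an explicit key, then slice fixed-size batches. iterate_minibatches and the toy arrays are hypothetical. Dropping the final partial batch is a design choice that keeps every batch the same shape, so a jitted update like the one in the next section would not recompile for a smaller last batch.

```python
import jax
import jax.numpy as jnp

def iterate_minibatches(key, X, y, batch_size):
    # Shuffle row indices with an explicit key, then slice fixed-size batches.
    # The remainder (fewer than batch_size rows) is dropped so shapes stay constant.
    perm = jax.random.permutation(key, X.shape[0])
    n_full = X.shape[0] // batch_size
    for i in range(n_full):
        idx = perm[i * batch_size:(i + 1) * batch_size]
        yield X[idx], y[idx]

key = jax.random.PRNGKey(1)
X = jnp.arange(20.0).reshape(10, 2)  # hypothetical toy data: row i = [2i, 2i + 1]
y = jnp.arange(10.0)                 # label i belongs to row i
batches = list(iterate_minibatches(key, X, y, batch_size=4))
print(len(batches), [xb.shape for xb, _ in batches])  # → 2 [(4, 2), (4, 2)]
```

Using the same index slice for X and y keeps features and labels aligned after shuffling; pass a fresh key each epoch to get a new order.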

4. Optimize with Optax

Optax supplies familiar optimizers with JAX-native updates. We’ll define a minimal parameter dictionary, compute the loss, and run the update step under jit. Everything here runs in a single Colab cell.

import optax

def model(params, X):
    return X @ params["w"] + params["b"]

def loss_fn(params, X, y):
    logits = model(params, X)
    losses = optax.sigmoid_binary_cross_entropy(logits, y)
    return losses.mean()

params = {"w": jnp.zeros((2,)), "b": jnp.zeros(())}  # array leaves keep input types stable across jitted steps
optimizer = optax.adam(learning_rate=3e-2)
opt_state = optimizer.init(params)

@jit
def update(params, opt_state, X, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, X, y)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    # loss was computed before the update; accuracy below uses the updated params
    probs = jax.nn.sigmoid(model(params, X))
    accuracy = jnp.mean((probs > 0.5) == y)
    return params, opt_state, loss, accuracy

for step in range(1, 501):
    params, opt_state, loss, acc = update(params, opt_state, X_train, y_train)
    if step % 100 == 0:
        print(f"step {step:03d} loss={loss:.4f} acc={acc:.3f}")
step 100 loss=0.4161 acc=0.805
step 200 loss=0.4136 acc=0.805
step 300 loss=0.4136 acc=0.805
step 400 loss=0.4136 acc=0.805
step 500 loss=0.4136 acc=0.805

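On CPU the loop completes in a few seconds. On a GPU runtime the same JIT-compiled update runs unchanged, with XLA lowering the linear algebra to GPU kernels; for a problem this small (512 examples, 2 features), though, per-step overhead dominates, so expect little speedup until the data or model grows.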

5. Vectorize predictions with vmap

vmap vectorizes a function written for a single example, so you don’t have to write explicit loops. Here we turn a single-example prediction into a batched predictor that you can reuse for evaluation metrics or posterior sampling later on.

from functools import partial

def predict_single(params, x):
    return jax.nn.sigmoid(jnp.dot(x, params["w"]) + params["b"])

batched_predict = jax.vmap(partial(predict_single, params), in_axes=(0,))
probs = batched_predict(X_train)
roc_points = jnp.stack([y_train, probs], axis=1)
print("First five probability pairs:\n", roc_points[:5])
First five probability pairs:
[[1.         0.9177734 ]
 [1.         0.32183945]
 [0.         0.04654315]
 [1.         0.88042223]
 [1.         0.992222  ]]

functools.partial freezes params into the function, so vmap sees a single argument (the batch of feature vectors) and maps over its axis 0. The result is identical to jax.vmap(predict_single, in_axes=(None, 0))(params, X_train), but partial keeps the call site tidy and avoids repeating the in_axes tuple every time you reuse the helper. Keep in mind that partial captures the params values that exist when batched_predict is created: if you wrap it in jax.jit, those values are embedded in the compiled computation as constants, so the compiled function is reused for any batch of the same shape but will not see later parameter updates. When parameters keep changing, as in a training loop, pass them explicitly with in_axes=(None, 0) instead. Reach for the same pattern when you need per-example losses, gradient inspection, or ensemble voting without writing explicit Python loops.
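Here is a small sketch of both ideas: it confirms that the partial and in_axes=(None, 0) forms agree, then uses the same in_axes pattern to compute per-example gradients and checks that their mean equals the gradient of the mean loss. example_loss, the parameter values, and the six-row dataset are illustrative assumptions.

```python
from functools import partial

import jax
import jax.numpy as jnp

def predict_single(params, x):
    return jax.nn.sigmoid(jnp.dot(x, params["w"]) + params["b"])

def example_loss(params, x, y):
    # Binary cross-entropy for one example, written with log_sigmoid for stability
    z = jnp.dot(x, params["w"]) + params["b"]
    return -(y * jax.nn.log_sigmoid(z) + (1.0 - y) * jax.nn.log_sigmoid(-z))

params = {"w": jnp.array([1.5, -0.5]), "b": jnp.array(-0.3)}  # illustrative values
X = jax.random.normal(jax.random.PRNGKey(0), (6, 2))
y = jnp.array([1.0, 0.0, 1.0, 1.0, 0.0, 0.0])

# Two equivalent ways to batch over examples while sharing params
via_partial = jax.vmap(partial(predict_single, params))(X)
via_in_axes = jax.vmap(predict_single, in_axes=(None, 0))(params, X)
print(jnp.allclose(via_partial, via_in_axes))

# Per-example gradients: grad w.r.t. params, vmapped over (x, y) pairs
per_example = jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0))(params, X, y)
print(per_example["w"].shape, per_example["b"].shape)  # (6, 2) (6,)

# Their mean matches the gradient of the mean loss
def mean_loss(p):
    return jax.vmap(example_loss, in_axes=(None, 0, 0))(p, X, y).mean()

full = jax.grad(mean_loss)(params)
print(jnp.allclose(per_example["w"].mean(axis=0), full["w"], atol=1e-6))
```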

6. Profile and debug with structured logging

Once the hot path stabilizes, wrap helper functions with jax.jit. To inspect intermediate values, use jax.debug.print: unlike Python’s print, it reports runtime values both outside and inside jitted functions without breaking compilation.

@jax.jit
def evaluate(params, X, y):
    logits = model(params, X)
    preds = (jax.nn.sigmoid(logits) > 0.5).astype(jnp.float32)
    accuracy = jnp.mean(preds == y)
    return accuracy

acc = evaluate(params, X_train, y_train)
jax.debug.print("Training accuracy: {acc:.3f}", acc=acc)
Training accuracy: 0.805

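Colab’s %timeit and %%timeit magics work with JAX functions too. Call the JIT-ed function once before benchmarking so the compilation cost is paid up front, and end the timed expression with .block_until_ready(): JAX dispatches work asynchronously, so without it you may only measure how long it takes to enqueue the computation.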
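A sketch of the same benchmarking discipline outside the notebook magics, using Python’s timeit module; matmul_tanh and the 512 × 512 inputs are arbitrary stand-ins for your own jitted function.

```python
import timeit

import jax
import jax.numpy as jnp

@jax.jit
def matmul_tanh(a, b):
    return jnp.tanh(a @ b)

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (512, 512))
b = jax.random.normal(key_b, (512, 512))

# Warm-up call: pays the trace/compile cost and waits for the result.
matmul_tanh(a, b).block_until_ready()

# Block on each result inside the timed statement; otherwise asynchronous
# dispatch means timeit may only measure how long it takes to enqueue the work.
n = 20
seconds = timeit.timeit(lambda: matmul_tanh(a, b).block_until_ready(), number=n)
print(f"{1e3 * seconds / n:.3f} ms per call")
```

In a Colab cell the equivalent one-liner is %timeit matmul_tanh(a, b).block_until_ready(), run after the warm-up call.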

7. Takeaways

  • JAX stays NumPy-like while unlocking auto-diff, vectorization, and compilation in one toolbox.
  • In Google Colab, one %pip install cell sets up the runtime; after that, jit, grad, and vmap compose naturally.
  • Optax supplies optimizer updates that slot directly into JAX training loops, so building logistic regression (or deeper models) is a few dozen lines.
  • A 500-step full-batch run plateaued at a loss of 0.4136 and 80.5% training accuracy, a quick sanity check that the training loop works before you explore bigger models or new features.
  • Explicit PRNGKey management and functional programming patterns keep notebook experiments reproducible.
  • The same pattern scales: swap the linear model for a Flax module, vectorize over thousands of data points, and lean on XLA to run it fast on TPUs or GPUs.