Welcome to JAX 201! After mastering transforms like jit, grad, and vmap, the next jump is harnessing multiple accelerators at once. This guide shows how to turn a single-device notebook into a synchronous data-parallel trainer that runs across all GPUs on your machine.

You will learn how to (1) confirm what hardware JAX sees, (2) reshape batches so each device gets a shard, (3) write a pmap-powered training step that keeps replicas in sync, and (4) debug common multi-GPU gotchas without breaking out of your notebook.

0. Notebook setup

Run the same Colab or workstation environment from JAX 101–103, but this time install CUDA-enabled wheels so every GPU is visible to JAX. Verify jaxlib is built with CUDA support before continuing.

%pip install -q -U "jax[cuda12]" jaxopt optax flax einops

Restart the runtime after the install, then double-check device visibility from Python.

import jax

print("Backend:", jax.default_backend())
print("Total devices:", jax.device_count())
print("Local devices:", jax.local_device_count())
print("Device kinds:", {device.device_kind for device in jax.devices()})
Backend: gpu
Total devices: 4
Local devices: 4
Device kinds: {'NVIDIA A100-SXM4-40GB'}

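If the counts come back as 1, your runtime exposes only one GPU. Standard Colab GPU runtimes typically attach a single accelerator, so switch to a multi-GPU workstation or cloud VM, and check that its drivers expose every GPU to CUDA (for example, that CUDA_VISIBLE_DEVICES is not hiding any). JAX 201 assumes you have at least two GPUs.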
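
If you would rather fail fast than discover a single-device runtime halfway through the guide, a small guard can help. The sketch below is illustrative only: require_devices is a hypothetical helper name, not part of JAX, and it relies only on jax.local_device_count().

```python
import jax


def require_devices(minimum):
    """Return the local device count, or raise if fewer than `minimum` are visible.

    Hypothetical helper for this guide; not a JAX API.
    """
    found = jax.local_device_count()
    if found < minimum:
        raise RuntimeError(
            f"Expected at least {minimum} local devices, but JAX sees {found}."
        )
    return found


# The guide assumes two or more GPUs; 1 is used here so the cell runs anywhere.
num_devices = require_devices(1)
print("Usable local devices:", num_devices)
```

In a real multi-GPU notebook you would call require_devices(2) at the top of the notebook.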

1. The data-parallel mental model

jax.pmap replicates your function across devices. Each replica:

  • Receives a shard of your batch (e.g., batch size 256 becomes four shards of 64 on four GPUs).
  • Runs the same computation independently.
  • Participates in collective operations (like mean-reducing gradients) via an axis_name.

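Because pmap compiles a single program and launches it everywhere, you should design your training step as if it runs on one shard. Parameters are not replicated for you: either give them a leading device axis yourself (section 3) or pass in_axes=None for that argument so pmap broadcasts one copy to every device. Per-device inputs flow in through the leading axis.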

import jax.numpy as jnp

per_device_batch_size = 64
global_batch_size = per_device_batch_size * jax.local_device_count()

Keeping per_device_batch_size constant makes experimentation predictable when you scale the number of GPUs.
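
To make the mental model concrete, here is a minimal sketch of one function written for a single shard and then mapped over all local devices. The function name shard_mean and the tiny batch size are illustrative assumptions. The collective jax.lax.pmean averages the per-shard means, which equals the global mean because every shard has the same size.

```python
import jax
import jax.numpy as jnp
import numpy as np

n_devices = jax.local_device_count()
per_device_batch_size = 4  # deliberately tiny, for illustration only

# A fake global batch, reshaped to [num_devices, per_device_batch].
values = np.arange(n_devices * per_device_batch_size, dtype=np.float32)
sharded = values.reshape(n_devices, per_device_batch_size)


def shard_mean(x):
    # Written as if it runs on one shard of shape [per_device_batch].
    local_mean = jnp.mean(x)
    # Collective: average the per-shard means across every replica.
    return jax.lax.pmean(local_mean, axis_name="batch")


global_mean = jax.pmap(shard_mean, axis_name="batch")(sharded)

# The output keeps the device axis: one identical copy per replica.
print(global_mean.shape)
```

Note that the result has shape [num_devices] rather than being a scalar. Every entry holds the same value, which is what "keeping replicas in sync" looks like from the host.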

2. Sharding batches across devices

Shape your input data as [num_devices, per_device_batch, ...] before feeding it to your pmap-ed function. That leading dimension is how pmap knows which slice belongs on which device. jax.device_put_sharded can place shards explicitly (jax.lax.with_sharding_constraint plays a similar role, but for jit-compiled code rather than pmap). For most notebook workflows you simply reshape the host array once and let pmap scatter it automatically.

def prepare_batch(batch):
    images, labels = batch
    images = images.reshape((jax.local_device_count(), -1, *images.shape[1:]))
    labels = labels.reshape((jax.local_device_count(), -1))
    return images, labels

When using data loaders like TensorFlow Datasets or PyTorch DataLoader, set the global batch size upfront so reshaping simply splits the first dimension.

Two helpful invariants:

  • global_batch_size == per_device_batch_size * jax.local_device_count()
  • images.shape[0] == labels.shape[0] == jax.local_device_count()

If your loader yields a PyTree (e.g., (images, {"label": labels, "mask": mask})), apply the same reshape to every leaf with a tree_map.

from jax import tree_util

def shard(tree):
    return tree_util.tree_map(
        lambda x: x.reshape((jax.local_device_count(), -1, *x.shape[1:])),
        tree,
    )
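
The reshape only works when the leading dimension divides evenly by the device count, which is easy to violate on the last batch of an epoch. Below is a hedged variant of the helper above that checks this explicitly. shard_checked and its num_devices parameter are illustrative names, and the batch contents are made-up zeros.

```python
import jax
import numpy as np
from jax import tree_util


def shard_checked(tree, num_devices):
    """Reshape every leaf to [num_devices, per_device_batch, ...].

    Illustrative variant of `shard` that fails loudly on ragged batches.
    """

    def _reshape(x):
        if x.shape[0] % num_devices != 0:
            raise ValueError(
                f"Leading dimension {x.shape[0]} is not divisible by "
                f"{num_devices} devices."
            )
        return x.reshape((num_devices, -1, *x.shape[1:]))

    return tree_util.tree_map(_reshape, tree)


n = jax.local_device_count()
batch = (
    np.zeros((8 * n, 28, 28, 1), dtype=np.float32),
    {"label": np.zeros((8 * n,), dtype=np.int32),
     "mask": np.ones((8 * n,), dtype=bool)},
)
sharded_batch = shard_checked(batch, n)
print(tree_util.tree_map(lambda x: x.shape, sharded_batch))
```

If your loader can produce a short final batch, dropping it or padding it on the host keeps the shapes fixed from step to step.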

To keep data on-device between steps, move each shard explicitly.

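def to_device(tree):
    devices = jax.local_devices()
    # device_put_sharded expects one array per device for each leaf. Build that
    # list inside a single tree_map: a list returned from one tree_map would be
    # treated as a pytree node (not a leaf) by a second tree_map.
    return tree_util.tree_map(
        lambda x: jax.device_put_sharded(list(x), devices),
        tree,
    )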

Combine shard and to_device if you notice repeated host↔device transfers in your profiler. The end result should be a tree where the outermost axis matches the device count and the inner axes look exactly like the single-device version of your model.

3. Replicating parameters

Parameters should live on every device. jax.device_put_replicated copies a tree of arrays so each GPU starts with identical weights.

import optax
from flax.training import train_state

def create_train_state(rng, model, learning_rate):
    params = model.init(rng, jnp.ones([1, 28, 28, 1]))["params"]
    tx = optax.adamw(learning_rate)
    state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
    return jax.device_put_replicated(state, jax.local_devices())

The returned state is still a TrainState, but every array leaf now carries a leading axis of size jax.local_device_count(), with one copy per GPU. Updates inside pmap produce the same structure, keeping every copy aligned.
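
It can help to see what "one copy per device" means for the shapes, and how to get a single copy back for evaluation or checkpointing. The sketch below is a simplified stand-in, not the method used above: replicate and unreplicate are hand-written illustrative helpers that stack copies on the host and leave device placement to pmap. Flax ships similarly named utilities in flax.jax_utils.

```python
import jax
import jax.numpy as jnp


def replicate(tree, num_devices):
    # Stack identical copies along a new leading device axis.
    return jax.tree_util.tree_map(
        lambda x: jnp.stack([x] * num_devices), tree
    )


def unreplicate(tree):
    # All copies are identical, so the first one is representative.
    return jax.tree_util.tree_map(lambda x: x[0], tree)


n = jax.local_device_count()
params = {"dense": {"kernel": jnp.ones((3, 2)), "bias": jnp.zeros((2,))}}

replicated = replicate(params, n)
restored = unreplicate(replicated)

print(jax.tree_util.tree_map(lambda x: x.shape, replicated))
print(jax.tree_util.tree_map(lambda x: x.shape, restored))
```

Unreplicating before you save a checkpoint avoids writing the same weights once per device.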

4. Writing the pmap training step

Wrap your single-device training step with jax.pmap. Give it an axis_name so you can synchronize gradients with lax.pmean.

from functools import partial

@partial(jax.pmap, axis_name="batch")
def train_step(state, batch, rng):
    images, labels = batch

    def loss_fn(params):
        logits = state.apply_fn({"params": params}, images, rngs={"dropout": rng})
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
        return loss, logits

    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    grads = jax.lax.pmean(grads, axis_name="batch")
    new_state = state.apply_gradients(grads=grads)
    metrics = {
        "loss": jax.lax.pmean(loss, axis_name="batch"),
        "accuracy": jax.lax.pmean((jnp.argmax(logits, axis=-1) == labels).mean(), axis_name="batch"),
    }
    return new_state, metrics

Key details:

  • batch is a tuple (images, labels) already reshaped to [num_devices, per_device_batch, ...].
  • rng must be per-device. Split outside the pmap call so each replica sees a unique key.
  • lax.pmean averages gradients and metrics across replicas, ensuring synchronous SGD.
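
To see the synchronization at work without Flax or Optax, here is a self-contained sketch on a toy linear regression. Everything in it is an illustrative assumption: sgd_step, the synthetic data, the 0.1 learning rate, and plain SGD in place of the optimizer used above. The point is that averaging gradients with lax.pmean leaves every replica holding the same weights after each update.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

n = jax.local_device_count()


@partial(jax.pmap, axis_name="batch")
def sgd_step(w, x, y):
    # Inside pmap, shapes are per-device: w [3], x [16, 3], y [16].
    def loss_fn(w):
        return jnp.mean((x @ w - y) ** 2)

    loss, grad = jax.value_and_grad(loss_fn)(w)
    # Average gradients across replicas so every copy takes the same step.
    grad = jax.lax.pmean(grad, axis_name="batch")
    new_w = w - 0.1 * grad
    return new_w, jax.lax.pmean(loss, axis_name="batch")


# Synthetic data with a leading device axis: [num_devices, 16, 3].
data_rng = np.random.default_rng(0)
x = data_rng.normal(size=(n, 16, 3)).astype(np.float32)
true_w = np.array([1.0, -2.0, 0.5], dtype=np.float32)
y = x @ true_w

# Replicated starting weights: one identical copy per device.
w = np.zeros((n, 3), dtype=np.float32)

losses = []
for _ in range(50):
    w, loss = sgd_step(w, x, y)
    losses.append(float(loss[0]))

print("first loss:", losses[0], "last loss:", losses[-1])
```

If you remove the pmean on grad, each replica follows its own shard and the copies of w drift apart, which is the failure mode the collective prevents.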

5. Driving the training loop

Gather randomness, shard batches, and keep everything on-device to avoid host-device churn.

def train_epoch(state, train_loader, rng):
    metrics_accumulator = []
    num_devices = jax.local_device_count()

    for batch in train_loader:
        sharded_batch = prepare_batch(batch)
        # Split one host-side key each step: one half carries over to the next
        # step, the other fans out into one dropout key per device.
        rng, step_rng = jax.random.split(rng)
        rngs = jax.random.split(step_rng, num_devices)
        state, metrics = train_step(state, sharded_batch, rngs)
        metrics_accumulator.append(jax.device_get(metrics))

    stacked = {k: jnp.stack([m[k] for m in metrics_accumulator]) for k in metrics_accumulator[0]}
    return state, {k: float(v.mean()) for k, v in stacked.items()}

jax.device_get transfers lightweight metric snapshots back to the host; parameters and optimizer state stay replicated on the devices.
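
One way to derive fresh per-device keys at every step is to fold a step counter into a base key and then split the result once per device. This is a sketch under the assumption that you track an integer step yourself. per_device_keys is a hypothetical helper name, while jax.random.fold_in and jax.random.split are standard JAX calls.

```python
import jax
import numpy as np


def per_device_keys(base_key, step, num_devices):
    """Return one PRNG key per device for the given step (illustrative helper)."""
    # fold_in derives a key specific to this step without consuming base_key.
    step_key = jax.random.fold_in(base_key, step)
    # split fans the step key out into num_devices independent keys.
    return jax.random.split(step_key, num_devices)


base_key = jax.random.PRNGKey(0)
n = jax.local_device_count()

keys_step0 = per_device_keys(base_key, 0, n)
keys_step1 = per_device_keys(base_key, 1, n)

print(keys_step0.shape[0], "keys per step")
```

Because the keys depend only on the base key and the step number, resuming from a checkpoint at step k reproduces the same dropout masks as an uninterrupted run.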

6. Debugging multi-GPU runs

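  • Shape mismatches: Watch for [global_batch, ...] tensors sneaking into pmap. Adding assert images.shape[0] == jax.local_device_count() just before you call train_step saves time. Inside the mapped function the device axis has already been stripped, so the check belongs on the host side.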
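  • Silent recompiles: pmap retraces and recompiles whenever input shapes, dtypes, or static arguments change, so a ragged final batch or a Python value that varies per step triggers a fresh compile. Keep shapes fixed, and branch on traced values with lax.cond rather than Python if statements.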
  • Divergent replicas: If metric variance spikes, print per-device values via jax.debug.print("loss shard {}", loss) inside the mapped function, before the pmean. Leave ordered at its default, since ordered effects are not supported under pmap.
  • Out-of-memory: Start with smaller per_device_batch_size. Multi-GPU often tempts you to double the global batch, but each GPU still has finite memory.
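
The shape check from the first bullet generalizes to whole PyTrees. Below is a minimal sketch of a host-side validator. assert_device_axis is a hypothetical helper name, and the example shapes assume two devices and MNIST-sized images purely for illustration.

```python
import jax
import numpy as np


def assert_device_axis(tree, num_devices):
    """Raise if any leaf lacks a leading axis of size num_devices.

    Illustrative helper; call it on the host, before the pmap-ed function.
    """
    for i, leaf in enumerate(jax.tree_util.tree_leaves(tree)):
        shape = np.shape(leaf)
        if len(shape) == 0 or shape[0] != num_devices:
            raise ValueError(
                f"Leaf {i} has shape {shape}; expected a leading axis of "
                f"size {num_devices}."
            )


# Correctly sharded for two devices: [num_devices, per_device_batch, ...].
good = (np.zeros((2, 64, 28, 28, 1)), np.zeros((2, 64)))
# Still in [global_batch, ...] layout: the common mistake.
bad = (np.zeros((128, 28, 28, 1)), np.zeros((128,)))

assert_device_axis(good, num_devices=2)  # passes silently

try:
    assert_device_axis(bad, num_devices=2)
except ValueError as err:
    print("Caught:", err)
```

Running the check once per step costs almost nothing, because it only inspects shapes and never touches device memory.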

7. Beyond pmap

  • jax.jit with explicit shardings (it absorbed the older pjit API and relies on XLA's GSPMD partitioner) gives finer-grained array sharding when model parallelism matters.
  • jax.sharding.Mesh combined with NamedSharding lets you mix data and model parallel axes for very large models.
  • Libraries like Orbax Checkpointing and T5X provide production-ready wrappers once you need multi-host scaling.
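
As a taste of the Mesh and NamedSharding route, here is a minimal data-parallel sketch that uses jax.jit instead of pmap. The mesh axis name "data", the function mean_prediction, and the toy arrays are illustrative assumptions. The batch is split along its first axis while the weights are replicated, and no manual reshape to [num_devices, ...] is needed.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over every visible device, with a single axis named "data".
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# The global batch stays in [global_batch, features] layout.
global_batch = 8 * devices.size
batch = np.arange(global_batch * 4, dtype=np.float32).reshape(global_batch, 4)

data_sharding = NamedSharding(mesh, P("data"))  # split axis 0 across devices
replicated = NamedSharding(mesh, P())           # full copy on every device

x = jax.device_put(batch, data_sharding)
w = jax.device_put(jnp.ones((4,), dtype=jnp.float32), replicated)


@jax.jit
def mean_prediction(w, x):
    # Written for the global array; the compiler inserts the collectives.
    return jnp.mean(x @ w)


result = mean_prediction(w, x)
print(float(result))
```

Compared with pmap, the function body sees the global shape, and the sharding annotations on the inputs decide how the work is divided.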

As you add more GPUs or TPU chips, keep reasoning in terms of per-device work plus collectives. With the patterns in this guide, you can scale the same training loop from a single GPU notebook all the way to multi-host clusters.