Welcome to JAX 103! After getting comfortable with jit, grad, and lax control flow, the next milestone is vectorization. JAX encourages you to write scalar or single-example functions and then lift them across batch dimensions without rewriting loops. jax.vmap is the workhorse transform that achieves this, delivering Python-free batching that runs efficiently on accelerators and stays compatible with autodiff.

In this guide you will: (1) understand what vmap actually stages under the hood, (2) learn how to control the axes you map over, (3) combine vmap with grad and Jacobian utilities, and (4) see how to pair it with other primitives like scan and pmap when you move from single-device to multi-device workloads.

0. Notebook setup

Same environment as JAX 101/102. The only new imports are jax.vmap, jax.jacfwd, and jax.jacrev so we can peek at Jacobian construction.

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, jacfwd, jacrev
import optax

Most examples below call the vmapped function directly to keep the output easy to follow; section 6 wraps the final function in jit so you can see how compilation composes with vectorization. Remember that vmap itself is a pure transformation: it rewrites your function, and the result is still a regular JAX function that you can jit, differentiate, or nest again.

1. Why vmap exists

Python loops are eager and sequential. When you iterate over a batch dimension in pure Python, each iteration dispatches its own small operations from the interpreter, so the accelerator never sees the batch as a whole (and under jit the loop is unrolled at trace time, which inflates compile time). vmap solves this by tracing the single-example function once and replacing each operation with its batched counterpart, so the whole batch is processed by array-level operations.

Consider a simple function that computes logits and probabilities for one example:

def predict(params, features):
    logits = features @ params["w"] + params["b"]
    return jnp.array([logits, jax.nn.sigmoid(logits)])

Vectorizing over a mini-batch requires only one call:

batched_predict = vmap(predict, in_axes=(None, 0))

Here in_axes=(None, 0) means params is shared across the batch (None) while features consumes its leading axis (0). The output automatically stacks along a new leading dimension.

params = {"w": jnp.array([1.5, -0.5]), "b": 0.25}
X = jnp.array([[1.0, 2.0], [2.5, -1.0], [-0.25, 0.75]])
logits_probs = batched_predict(params, X)
print("shape:", logits_probs.shape)
print("first row:", logits_probs[0])
shape: (3, 2)
first row: [0.75       0.6791787]

If you want to see the entire stacked result, a quick print(logits_probs) shows the logits alongside their sigmoid activations:

[[ 0.75        0.6791787 ]
 [ 4.5         0.9890131 ]
 [-0.5         0.37754068]]

No Python loop runs over the batch: each operation executes once on the whole array. You could now wrap the function as jit(batched_predict) to compile it into a single XLA computation, or differentiate through it exactly as if you had written a bespoke batched function.
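
As a quick sanity check, the sketch below compares the vmapped function with an explicit Python loop and with its jitted version, then prints the jaxpr so you can see what vmap stages. It reuses predict and the sample data from this section; the names looped, vmapped, and jitted are introduced here only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import jit, vmap

def predict(params, features):
    logits = features @ params["w"] + params["b"]
    return jnp.array([logits, jax.nn.sigmoid(logits)])

params = {"w": jnp.array([1.5, -0.5]), "b": 0.25}
X = jnp.array([[1.0, 2.0], [2.5, -1.0], [-0.25, 0.75]])

batched_predict = vmap(predict, in_axes=(None, 0))

# Reference: an explicit Python loop over the rows of X.
looped = jnp.stack([predict(params, x) for x in X])
vmapped = batched_predict(params, X)
jitted = jit(batched_predict)(params, X)

print("loop == vmap:", bool(jnp.allclose(looped, vmapped)))
print("vmap == jit(vmap):", bool(jnp.allclose(vmapped, jitted)))

# The jaxpr contains batched array operations rather than a loop.
print(jax.make_jaxpr(batched_predict)(params, X))
```

The exact jaxpr text depends on your JAX version, but it should show a single matrix-vector style contraction over the whole batch instead of three separate dot products.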

2. Controlling input and output axes

in_axes tells vmap which dimension to iterate over for each argument. out_axes makes the same decision for outputs. The defaults (in_axes=0, out_axes=0) map over the leading dimension, but real datasets often need more flexibility.

If “axis” language feels abstract, picture each tensor as a stack of slices. in_axes=0 means “peel off slices along the first dimension and feed them one-by-one into the scalar function.” Marking an argument with None says “keep this whole object constant for every slice.” Likewise, out_axes answers “where should JAX restack the results in the returned array?” Keeping these mental pictures handy makes the next patterns easier to reason about.

Broadcasting only one argument

When broadcasting a scalar hyperparameter across batch elements, set that argument's in_axes entry to None. Read it as “do not slice this argument; reuse it for every mapped call.” The same intuition applies to PyTrees (lists, tuples, dicts): supply an in_axes tree that mirrors the argument structure, filling it with either an integer axis to slice or None to keep the entire branch constant.

def scaled_loss(scale, preds, targets):
    residual = preds - targets
    return scale * jnp.sum(residual**2)

batched_loss = vmap(scaled_loss, in_axes=(None, 0, 0))

scale = 0.5
preds = jnp.array([[0.2, 0.8], [0.6, 0.4]])
targets = jnp.array([[0.0, 1.0], [1.0, 0.0]])
print(batched_loss(scale, preds, targets))
[0.04 0.16]
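
The PyTree form of in_axes can be sketched as follows. The function affine and all parameter values are made up for illustration: the weights "w" differ per example while the bias "b" is shared, so the in_axes dict slices one branch and keeps the other whole.

```python
import jax.numpy as jnp
from jax import vmap

def affine(params, x):
    # Single-example function: one weight vector, one bias, one input.
    return x @ params["w"] + params["b"]

params = {
    "w": jnp.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]),  # (3, 2): one row per example
    "b": jnp.array(0.5),                                    # shared scalar
}
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])        # (3, 2)

# in_axes mirrors the arguments: slice "w" along axis 0, keep "b" whole,
# and slice xs along axis 0.
per_example = vmap(affine, in_axes=({"w": 0, "b": None}, 0))
out = per_example(params, xs)
print(out.shape)
print(out)
```

If every leaf of params were shared, a plain None for the whole dict would do the same job as spelling out each key.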

Moving the batch axis to the end

out_axes controls where the mapped (batch) axis lands in the output, which helps when a downstream library expects a particular layout. Negative indices count from the end, so out_axes=-1 places the batch axis last.

Start by reading the sample code below so the axis juggling is clear before you execute it; we construct two random “images” so you can track how the dimensions move.

def normalize(x):
    x_centered = x - jnp.mean(x)
    return x_centered / jnp.std(x)

images = jax.random.normal(jax.random.PRNGKey(0), (2, 28, 28))
print("images shape:", images.shape)

batch_first_norm = vmap(normalize, in_axes=0, out_axes=0)(images)
print("batch_first_norm shape:", batch_first_norm.shape)

batch_last_norm = vmap(normalize, in_axes=0, out_axes=-1)(images)
print("batch_last_norm shape:", batch_last_norm.shape)
images shape: (2, 28, 28)
batch_first_norm shape: (2, 28, 28)
batch_last_norm shape: (28, 28, 2)

The computation still runs per image; out_axes only changes where the batch axis sits in the stacked result (a transpose of the output, not a change to its values), here to satisfy a downstream API that expects the batch dimension at the end.
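
The same idea works on the input side. As a minimal sketch, assume a hypothetical batch-last array images_last (not part of the example above): in_axes=-1 slices along the final axis, and out_axes=0 restacks the results batch-first.

```python
import jax
import jax.numpy as jnp
from jax import vmap

def normalize(x):
    x_centered = x - jnp.mean(x)
    return x_centered / jnp.std(x)

# Hypothetical layout: two 28x28 images stacked along the last axis.
images_last = jax.random.normal(jax.random.PRNGKey(0), (28, 28, 2))

# Slice along the last axis on the way in, restack along axis 0 on the way out.
batch_first = vmap(normalize, in_axes=-1, out_axes=0)(images_last)
print("batch_first shape:", batch_first.shape)

# Equivalent: move the axis by hand, then map over the leading axis.
reference = vmap(normalize)(jnp.moveaxis(images_last, -1, 0))
print("matches manual moveaxis:", bool(jnp.allclose(batch_first, reference, atol=1e-5)))
```

Letting vmap handle the axis bookkeeping keeps normalize itself free of any layout assumptions.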

3. Vectorizing functions with multiple batch dimensions

A single vmap maps over one axis per argument; to cover several batch dimensions, nest vmap calls. Suppose we want pairwise Euclidean distances between each point in X and each point in Y. We can nest two vmap calls, each mapping over a different operand.

def euclidean_distance(x, y):
    return jnp.sqrt(jnp.sum((x - y) ** 2))

pairwise_dist = vmap(vmap(euclidean_distance, in_axes=(None, 0)), in_axes=(0, None))

X = jnp.array([[0., 0.], [1., 0.], [0., 1.]])
Y = jnp.array([[0., 0.], [1., 1.]])
print(pairwise_dist(X, Y))
[[0.        1.4142135]
 [1.        1.        ]
 [1.        1.        ]]

The outer vmap sweeps over rows of X, the inner one sweeps over rows of Y, and JAX turns the nested mapping into batched array operations with no Python loop (under jit, XLA can fuse them further). This pattern scales to more complex structures such as attention scores or kernel matrices.
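
To illustrate the kernel-matrix case, here is a sketch that applies the same nesting to an RBF kernel and checks it against a hand-written broadcasting version. The function rbf and the value gamma=0.5 are illustrative choices, not something defined earlier in this guide.

```python
import jax.numpy as jnp
from jax import jit, vmap

def rbf(x, y, gamma=0.5):
    # Kernel value for a single pair of points.
    return jnp.exp(-gamma * jnp.sum((x - y) ** 2))

# Inner vmap sweeps rows of Y, outer vmap sweeps rows of X,
# giving a matrix of shape (len(X), len(Y)).
kernel_matrix = jit(vmap(vmap(rbf, in_axes=(None, 0)), in_axes=(0, None)))

X = jnp.array([[0., 0.], [1., 0.], [0., 1.]])
Y = jnp.array([[0., 0.], [1., 1.]])
K = kernel_matrix(X, Y)

# Hand-written broadcasting version for comparison.
diff = X[:, None, :] - Y[None, :, :]
K_ref = jnp.exp(-0.5 * jnp.sum(diff ** 2, axis=-1))

print("K shape:", K.shape)
print("matches broadcasting:", bool(jnp.allclose(K, K_ref)))
```

Both versions compute the same matrix; the nested-vmap form keeps the pairwise function readable and leaves the broadcasting to JAX.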

4. Differentiation across batches

You can differentiate through vmap exactly as you would any other JAX function. Two common patterns:

  • Compute gradients per example (e.g., for per-sample losses or influence functions).
  • Assemble Jacobians without writing loops.

Per-example gradients

def loss_per_example(params, x, y):
    logits = x @ params["w"] + params["b"]
    return jnp.mean(optax.sigmoid_binary_cross_entropy(logits, y))

grad_per_example = vmap(grad(loss_per_example), in_axes=(None, 0, 0))

params = {"w": jnp.array([1.0, -2.0]), "b": 0.3}
features = jnp.array([[1.0, 0.0], [0.5, 1.5], [2.0, -1.0]])
targets = jnp.array([1.0, 0.0, 1.0])  # one binary label per example
grads = grad_per_example(params, features, targets)
print("dw shape:", grads["w"].shape)
print("db shape:", grads["b"].shape)
dw shape: (3, 2)
db shape: (3,)

Every example gets its own gradient, which you can later aggregate (jnp.mean, clipping, etc.) without leaving compiled JAX land.
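
As one possible aggregation, the sketch below clips each example's gradient to a maximum global L2 norm before averaging. The helper clipped_mean_grad, the threshold max_norm=0.25, and the hand-written cross-entropy loss_one (used here so the snippet does not depend on optax) are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
from jax import grad, vmap

def loss_one(params, x, y):
    # Binary cross-entropy on one example, from logits: softplus(z) - y * z.
    logit = x @ params["w"] + params["b"]
    return jnp.logaddexp(0.0, logit) - y * logit

def clipped_mean_grad(params, xs, ys, max_norm):
    per_example = vmap(grad(loss_one), in_axes=(None, 0, 0))(params, xs, ys)
    leaves = jax.tree_util.tree_leaves(per_example)
    # Global L2 norm of each example's gradient, shape (batch,).
    sq = sum(jnp.sum(g.reshape(g.shape[0], -1) ** 2, axis=1) for g in leaves)
    norms = jnp.sqrt(sq)
    factor = jnp.minimum(1.0, max_norm / (norms + 1e-12))

    def clip_and_mean(g):
        # Broadcast the per-example factor over the remaining axes of g.
        f = factor.reshape((-1,) + (1,) * (g.ndim - 1))
        return jnp.mean(g * f, axis=0)

    return jax.tree_util.tree_map(clip_and_mean, per_example), norms

params = {"w": jnp.array([1.0, -2.0]), "b": jnp.array(0.3)}
features = jnp.array([[1.0, 0.0], [0.5, 1.5], [2.0, -1.0]])
targets = jnp.array([1.0, 0.0, 1.0])

mean_grads, norms = clipped_mean_grad(params, features, targets, max_norm=0.25)
print("per-example norms:", norms)
print("clipped mean dw:", mean_grads["w"], "db:", mean_grads["b"])
```

Because everything is expressed with vmap and tree utilities, the whole function can be wrapped in jit without changes.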

Jacobians via jacfwd and jacrev

jacfwd and jacrev are built on vmap internally. You can also assemble Jacobians by hand by wrapping directional derivatives (jax.jvp or jax.vjp) in vmap, which gives you control over which rows or columns are computed and how they are laid out in memory.

def softmax_logits(theta):
    return jax.nn.softmax(theta)

theta = jnp.array([1.0, 0.0, -1.0])
jacobian = jacrev(softmax_logits)(theta)
print(jacobian)
[[ 0.22269541 -0.1628034  -0.05989202]
 [-0.1628034   0.18483643 -0.02203304]
 [-0.05989202 -0.02203305  0.08192506]]

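Behind the scenes, jacrev builds one vector-Jacobian product (VJP) function and uses vmap to apply it to every standard basis vector of the output space; jacfwd does the same with Jacobian-vector products (JVPs) over the input basis. Knowing this helps when you need custom Jacobian-vector products or batched Hessians: you can always break them down into combinations of vmap with grad, jvp, or vjp.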
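
A hand-rolled version makes the mechanism concrete. This is a simplified sketch of the idea, not the library's actual implementation: it builds the Jacobian of softmax_logits row by row with vmap over a VJP and column by column with vmap over a JVP, then compares both with jacrev and jacfwd.

```python
import jax
import jax.numpy as jnp
from jax import jacfwd, jacrev, vmap

def softmax_logits(theta):
    return jax.nn.softmax(theta)

theta = jnp.array([1.0, 0.0, -1.0])
basis = jnp.eye(theta.shape[0])

# Reverse mode: pull each output basis vector back through one shared VJP.
_, pullback = jax.vjp(softmax_logits, theta)
(jac_rows,) = vmap(pullback)(basis)        # row i holds d y_i / d theta

# Forward mode: push each input basis vector through a JVP.
def push(v):
    return jax.jvp(softmax_logits, (theta,), (v,))[1]

jac_cols = vmap(push, out_axes=1)(basis)   # column j holds d y / d theta_j

print("rows match jacrev:", bool(jnp.allclose(jac_rows, jacrev(softmax_logits)(theta))))
print("cols match jacfwd:", bool(jnp.allclose(jac_cols, jacfwd(softmax_logits)(theta))))
```

Passing only a subset of basis vectors to either vmap call yields just the rows or columns you need, which is the main reason to drop down to this level.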

5. Randomness and vmap

Because JAX requires explicit PRNG handling, vectorizing stochastic functions means threading unique keys through the batch. The canonical pattern is to split once per example.

def dropout_layer(key, x, rate=0.1):
    keep = jax.random.bernoulli(key, p=1.0 - rate, shape=x.shape)
    return keep * x / (1.0 - rate)

def batched_dropout(key, inputs):
    keys = jax.random.split(key, inputs.shape[0])
    return vmap(dropout_layer)(keys, inputs)

key = jax.random.PRNGKey(42)
activations = jnp.ones((4, 5))
print(batched_dropout(key, activations))
[[1.1111112 1.1111112 0.        1.1111112 1.1111112]
 [1.1111112 1.1111112 1.1111112 1.1111112 1.1111112]
 [1.1111112 1.1111112 1.1111112 1.1111112 1.1111112]
 [1.1111112 1.1111112 1.1111112 1.1111112 1.1111112]]

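vmap zipped over (keys, inputs) because both arguments default to in_axes=0. Setting in_axes=(None, 0) instead shares one key across all rows, which applies the identical dropout mask to every example; that is occasionally useful for reproducing a fixed mask, but it removes the independence between examples that dropout normally relies on.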
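
The difference between a shared key and per-example keys is easy to check. In this sketch the rate is raised to 0.5 (an arbitrary choice) so that dropped entries are common enough to see.

```python
import jax
import jax.numpy as jnp
from jax import vmap

def dropout_layer(key, x, rate=0.5):
    keep = jax.random.bernoulli(key, p=1.0 - rate, shape=x.shape)
    return keep * x / (1.0 - rate)

key = jax.random.PRNGKey(0)
inputs = jnp.ones((4, 8))

# Shared key: the key is not sliced, so every row draws the same mask.
shared = vmap(dropout_layer, in_axes=(None, 0))(key, inputs)

# Split keys: one key per row, so each row draws its own mask.
keys = jax.random.split(key, inputs.shape[0])
independent = vmap(dropout_layer)(keys, inputs)

print("shared rows identical:", bool(jnp.all(shared == shared[0])))
print(independent)
```

With split keys the rows will generally differ, which is what you want during training.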

6. vmap versus scan versus pmap

These three tools solve different problems:

  • vmap vectorizes a function across a batch dimension on a single device.
  • scan sequences a computation with carried state (e.g., RNN time steps).
  • pmap shards both data and computation across multiple devices.

They compose nicely. For instance, you can run a time-major RNN on each example via scan, then apply that RNN to each element of a batch via vmap, and finally distribute the batch across devices with pmap.

def rnn_cell(carry, x_t, params):
    h_prev = carry
    w_hh, w_xh, b = params
    h = jnp.tanh(h_prev @ w_hh + x_t @ w_xh + b)
    return h, h

def run_sequence(inputs, params):
    h0 = jnp.zeros(params[0].shape[0])
    _, outputs = jax.lax.scan(lambda c, x: rnn_cell(c, x, params), h0, inputs)
    return outputs

batched_run_sequence = vmap(run_sequence, in_axes=(0, None))
jitted_batched_run = jit(batched_run_sequence)

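You can now feed a batch of sequences with shape (batch, time, features) into jitted_batched_run and still take gradients with respect to params. If you later need to scale to multiple GPUs or TPUs, you can wrap batched_run_sequence in pmap with the same in_axes=(0, None), after adding a leading device axis to the inputs, without rewriting the recurrence.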
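
For a concrete run, the sketch below feeds random data through the batched RNN and differentiates a toy loss with respect to the parameters. The sizes B, T, D, H, the random initialization, and the mean-square loss are illustrative assumptions only.

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap

def rnn_cell(carry, x_t, params):
    w_hh, w_xh, b = params
    h = jnp.tanh(carry @ w_hh + x_t @ w_xh + b)
    return h, h

def run_sequence(inputs, params):
    h0 = jnp.zeros(params[0].shape[0])
    _, outputs = jax.lax.scan(lambda c, x: rnn_cell(c, x, params), h0, inputs)
    return outputs

batched_run = jit(vmap(run_sequence, in_axes=(0, None)))

# Illustrative sizes: 4 sequences, 6 time steps, 3 features, 5 hidden units.
B, T, D, H = 4, 6, 3, 5
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (
    0.1 * jax.random.normal(k1, (H, H)),  # w_hh
    0.1 * jax.random.normal(k2, (D, H)),  # w_xh
    jnp.zeros(H),                         # b
)
inputs = jax.random.normal(k3, (B, T, D))

outputs = batched_run(inputs, params)
print("outputs shape:", outputs.shape)    # (batch, time, hidden)

def loss(params, inputs):
    # Toy objective: mean squared hidden activation.
    return jnp.mean(batched_run(inputs, params) ** 2)

grads = grad(loss)(params, inputs)
print("grad shapes:", [g.shape for g in grads])
```

The gradient has the same tuple structure as params, so it can be passed straight to an optimizer update.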

7. Performance checklist

To make the most of vmap:

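  • Keep mapped functions pure and side-effect-free. vmap traces the function once, so a Python side effect such as print or a list append runs a single time at trace time and sees tracers rather than per-example values; use jax.debug.print when you need runtime values.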
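  • Avoid data-dependent Python control flow inside the mapped body; a Python if on a mapped value raises a tracer error. Use jnp.where, lax.cond, or lax.switch instead, keeping in mind that a cond with a batched predicate is converted to a select, so both branches are evaluated.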
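  • Combine vmap with jit for heavy workloads. Both jit(vmap(f)) and vmap(jit(f)) work and neither recompiles per element, but putting jit outermost is the usual choice because XLA then sees and can optimize the whole batched computation, including any code around the vmap call.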
  • Profile memory: vmap materializes stacked outputs. If an intermediate array is huge, consider reducing within the mapped body or switching to scan when you only need aggregated results.
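
The control-flow item can be illustrated with a small sketch. The three function names below are hypothetical; the point is that a Python if fails on a mapped value, while lax.cond and jnp.where both work under vmap.

```python
import jax
import jax.numpy as jnp
from jax import vmap

def python_branch(x):
    # A Python `if` needs a concrete bool, which a traced value cannot provide.
    if x > 0:
        return jnp.sqrt(x)
    return -x

def cond_branch(x):
    return jax.lax.cond(x > 0, jnp.sqrt, lambda v: -v, x)

def where_branch(x):
    # Both sides are computed, so guard sqrt against negative inputs.
    safe = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.sqrt(safe), -x)

xs = jnp.array([4.0, -2.0, 9.0])

try:
    vmap(python_branch)(xs)
    python_ok = True
except jax.errors.ConcretizationTypeError:
    python_ok = False

print("Python if works under vmap:", python_ok)
print("lax.cond:", vmap(cond_branch)(xs))
print("jnp.where:", vmap(where_branch)(xs))
```

The guard in where_branch matters mostly for gradients: without it, the unused sqrt of a negative number can leak NaN into the backward pass even though the forward value is correct.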

8. Where to go next

  • Replace your manual Python loops in research notebooks with vmap to unlock better accelerator utilization.
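  • Explore vmap's axis_name argument, which lets collectives such as jax.lax.psum reduce over a mapped axis; the same named-axis idea carries over to pmap and shard_map when you move to multi-device code.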
  • Read how libraries like Flax or Equinox supply batched modules—you will recognize the same vmap patterns applied to parameter PyTrees.

The leap from jit and grad to vmap is what makes JAX feel composable. You now have the tools to write single-example APIs that scale across data, time, and devices with a handful of transformations.