medium framework

Vectorize with vmap

Use vmap to vectorize a function that computes the dot product of two vectors, applying it across batches of vector pairs.

Given two 2D tensors A and B of shape (batch, n), compute the dot product of each pair of rows without using an explicit loop.

Input:

  • A: A 2D tensor of shape (batch, n)
  • B: A 2D tensor of shape (batch, n)

Output: A 1D tensor of shape (batch,) where each element is the dot product of the corresponding rows.
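
To make the expected output concrete, here is a small sketch of the input/output contract. The values of A and B are made-up illustration data (batch = 2, n = 3), and the elementwise multiply-and-sum is only a reference computation for checking results, not the vmap solution the task asks for.

```python
import jax.numpy as jnp

# Hypothetical example data: batch = 2, n = 3.
A = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])
B = jnp.array([[1.0, 0.0, 1.0],
               [2.0, 2.0, 2.0]])

# Reference result without vmap: multiply elementwise, then sum over n.
# Row 0: 1*1 + 2*0 + 3*1 = 4
# Row 1: 4*2 + 5*2 + 6*2 = 30
expected = jnp.sum(A * B, axis=1)

print(expected.shape)  # one value per row of the batch
print(expected)
```

Any correct solution should produce a tensor with the same shape and values as `expected` for these inputs.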

API Reference:

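  • JAX: jax.vmap(fn)(A, B), where fn takes a single pair of 1D vectors; by default (in_axes=0) the leading axis of each argument is mapped over
  • PyTorch: torch.vmap(fn)(A, B), where fn takes a single pair of 1D vectors; by default (in_dims=0) the leading dimension of each argument is mapped over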
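
As a minimal sketch of how the JAX call might be used, the example below writes a function for a single pair of vectors and lets jax.vmap lift it over the batch axis. The helper names `dot` and `batched_dot` and the sample data are assumptions for illustration; they do not come from the problem statement.

```python
import jax
import jax.numpy as jnp

def dot(a, b):
    # Works on ONE pair of 1D vectors, each of shape (n,).
    return jnp.dot(a, b)

# jax.vmap maps over axis 0 of every argument by default (in_axes=0),
# so the wrapped function accepts inputs of shape (batch, n).
batched_dot = jax.vmap(dot)

# Hypothetical example data: batch = 2, n = 3.
A = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])
B = jnp.array([[1.0, 0.0, 1.0],
               [2.0, 2.0, 2.0]])

result = batched_dot(A, B)  # shape (batch,), one dot product per row pair
print(result)
```

The per-example function never sees the batch dimension, which is the point of vmap: the batching is added by the transform rather than written by hand. The PyTorch form listed in the API reference, torch.vmap(fn)(A, B), is called the same way.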

vmap vectorization jax.vmap torch.vmap dot-product