Use vmap to vectorize a function that computes the dot product of two vectors,
applying it across batches of vector pairs.
Given two 2D tensors A and B of shape (batch, n), compute the dot product
of each pair of rows without using an explicit loop.
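A minimal sketch of one possible solution, assuming JAX is installed. The per-pair function `dot` is an illustrative name, not one given by the task. `jax.vmap` maps it over the leading axis of both arguments, so row i of A is paired with row i of B and no Python loop is needed.

```python
import jax
import jax.numpy as jnp

def dot(x, y):
    # Dot product of two 1D vectors of shape (n,).
    return jnp.dot(x, y)

# vmap maps `dot` over axis 0 of both arguments (the default in_axes=0),
# pairing each row of A with the matching row of B.
batched_dot = jax.vmap(dot)

A = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])  # shape (batch=2, n=3)
B = jnp.array([[1.0, 0.0, 1.0],
               [2.0, 1.0, 0.0]])  # shape (batch=2, n=3)

result = batched_dot(A, B)  # shape (2,); values 4.0 and 13.0
print(result)
```

`dot` is written for a single pair of vectors. `vmap` lifts it to the batched case, so the batching logic never appears in the function body.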
Input:
A: A 2D tensor of shape (batch, n)
B: A 2D tensor of shape (batch, n)
Output: A 1D tensor of shape (batch,) whose i-th element is the dot product of A[i] and B[i].
API Reference:
jax.vmap(fn)(A, B)
torch.vmap(fn)(A, B)
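As a sanity check on the `jax.vmap(fn)(A, B)` form above, the sketch below writes the batching axes out explicitly with `in_axes=(0, 0)`, which matches the default. It then compares the result against a loop-free reference built with `jnp.einsum`. The random shapes (8, 5) and the helper name `dot` are illustrative assumptions, not part of the task.

```python
import jax
import jax.numpy as jnp

def dot(x, y):
    # Per-pair dot product of two (n,) vectors, written as an elementwise
    # product followed by a sum.
    return jnp.sum(x * y)

# in_axes=(0, 0) spells out the default: batch over axis 0 of A and of B.
batched_dot = jax.vmap(dot, in_axes=(0, 0))

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(key_a, (8, 5))  # (batch, n)
B = jax.random.normal(key_b, (8, 5))  # (batch, n)

out = batched_dot(A, B)

# Reference computed without vmap: contract over n, keep the batch axis.
ref = jnp.einsum("bn,bn->b", A, B)
print(out.shape, jnp.allclose(out, ref, atol=1e-5))
```

Spelling out `in_axes` is optional here. It becomes necessary when only some arguments are batched, for example `in_axes=(0, None)` to dot every row of A with one shared vector.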