
JAX numpy vs PyTorch ops

Why this matters

JAX’s jnp namespace is intentionally close to the NumPy API: if you can use numpy, you mostly know jnp. PyTorch’s torch API covers many of the same operations. Some share the name and some use different argument conventions: torch.matmul → jnp.matmul; torch.transpose(t, dim0, dim1) → jnp.transpose(arr, axes); torch.zeros_like → jnp.zeros_like. Translating between PyTorch and JAX is therefore mostly a search-and-replace exercise, once you know the few subtle differences.
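One quick way to see the NumPy compatibility is to write a function once against a generic array module and run it with both numpy and jax.numpy. The helper relu_gram is hypothetical, made up only for this illustration.

```python
import numpy as np
import jax.numpy as jnp

def relu_gram(xp, a):
    """Hypothetical helper: max(a @ a.T - 20, 0), written once against the
    array module `xp` (either numpy or jax.numpy)."""
    return xp.maximum(xp.matmul(a, a.T) - 20.0, 0.0)

a = np.arange(6.0).reshape(2, 3)          # [[0, 1, 2], [3, 4, 5]]
out_np = relu_gram(np, a)                 # NumPy ndarray (float64)
out_jnp = relu_gram(jnp, jnp.asarray(a))  # JAX Array (float32 under the default config)

print(out_np.tolist())                           # [[0.0, 0.0], [0.0, 30.0]]
print(np.allclose(out_np, np.asarray(out_jnp)))  # True
```

Only the module handle changes. The call names (maximum, matmul, .T) are identical.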

PyTorch                     JAX
torch.matmul(a, b)          jnp.matmul(a, b) (or a @ b)
torch.where(cond, t, f)     jnp.where(cond, t, f)
torch.zeros_like(t)         jnp.zeros_like(t)
torch.transpose(t, 0, 1)    jnp.swapaxes(t, 0, 1), or jnp.transpose(t, (1, 0)) for 2-D (full perm tuple, not two axes)
t.reshape(-1)               t.reshape(-1) (works in both)
torch.relu(t)               jax.nn.relu(t) (or jnp.maximum(t, 0))
t.permute(2, 0, 1)          jnp.transpose(t, (2, 0, 1))
t.unsqueeze(0)              t[None] or jnp.expand_dims(t, 0)
t.squeeze(0)                jnp.squeeze(t, 0) (raises an error if axis 0 does not have size 1; torch returns t unchanged)
t.detach()                  jax.lax.stop_gradient(t)
with torch.no_grad():       not needed: JAX differentiates only inside jax.grad and related transforms; use jax.lax.stop_gradient to block specific values
t.cuda()                    jax.device_put(t, jax.devices('gpu')[0]) (new arrays already land on the default device, a GPU when one is available)
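Several rows of the table can be exercised side by side on a rank-3 array. This is a minimal sketch. One substitution: it uses jax.devices()[0] (the first device of the default backend) instead of jax.devices('gpu')[0], so that it also runs on a CPU-only install.

```python
import jax
import jax.numpy as jnp

t = jnp.arange(24.0).reshape(2, 3, 4)      # shape (2, 3, 4)

permuted = jnp.transpose(t, (2, 0, 1))     # torch: t.permute(2, 0, 1)         -> (4, 2, 3)
swapped = jnp.swapaxes(t, 0, 1)            # torch: torch.transpose(t, 0, 1)   -> (3, 2, 4)
unsq = t[None]                             # torch: t.unsqueeze(0)             -> (1, 2, 3, 4)
assert jnp.array_equal(unsq, jnp.expand_dims(t, 0))   # both spellings agree
sq = jnp.squeeze(unsq, 0)                  # torch: t.squeeze(0)               -> (2, 3, 4)
flat = t.reshape(-1)                       # same call in both                 -> (24,)
relu = jax.nn.relu(t - 10.0)               # torch: torch.relu(t - 10)
assert jnp.array_equal(relu, jnp.maximum(t - 10.0, 0.0))

print(permuted.shape, swapped.shape, unsq.shape, sq.shape, flat.shape)
# (4, 2, 3) (3, 2, 4) (1, 2, 3, 4) (2, 3, 4) (24,)

# torch: y.detach(). stop_gradient makes the second factor a constant.
def f(x):
    return x * jax.lax.stop_gradient(x)    # d/dx = x (a plain x * x would give 2x)

print(jax.grad(f)(3.0))                    # 3.0

# torch: t.cuda(). jax.devices()[0] is the first device of the default
# backend (a GPU if one is available).
t_dev = jax.device_put(t, jax.devices()[0])
```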

Worked mini-example

The biggest API gotcha: torch.transpose(t, 0, 1) swaps two axes, while jnp.transpose takes a FULL axis-permutation tuple. For 2-D arrays, jnp.transpose(t, (1, 0)) swaps the two axes, and t.T is equivalent and more idiomatic. For higher ranks, jnp.swapaxes(t, 0, 1) is the direct analogue of torch.transpose(t, 0, 1).

# PyTorch
import torch
y = torch.tensor([[1., 2.], [3., 4.]])
torch.transpose(y, 0, 1)   # [[1,3],[2,4]]

# JAX
import jax.numpy as jnp
y = jnp.array([[1., 2.], [3., 4.]])
jnp.transpose(y, (1, 0))   # [[1,3],[2,4]]  ← full perm tuple
y.T                         # same result, more idiomatic
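The 2-D case hides the difference. This is a short sketch on a rank-3 array, showing what the two-axis swap looks like in JAX and why calling jnp.transpose without axes is not a substitute.

```python
import jax.numpy as jnp

t = jnp.arange(24.0).reshape(2, 3, 4)   # a rank-3 array

# torch.transpose(t, 0, 1) swaps axes 0 and 1 and leaves axis 2 alone.
a = jnp.swapaxes(t, 0, 1)               # two-axis analogue of torch.transpose
b = jnp.transpose(t, (1, 0, 2))         # the same swap as a full permutation tuple
print(a.shape, b.shape)                 # (3, 2, 4) (3, 2, 4)
print(bool(jnp.array_equal(a, b)))      # True

# With no axes argument, jnp.transpose reverses ALL axes (like t.T),
# which differs from torch.transpose(t, 0, 1) once rank > 2.
print(jnp.transpose(t).shape)           # (4, 3, 2)
```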

Common pitfalls

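  • torch.transpose vs jnp.transpose argument format: torch takes two axis ints; jnp takes a tuple-of-all-axes permutation (jnp.swapaxes is the one that takes two ints).
  • Dtype defaults: both PyTorch and JAX default to float32 (with JAX’s default config, where x64 is disabled). Integer defaults differ: torch.tensor([1, 2]) is int64, while jnp.array([1, 2]) is int32 unless x64 is enabled.
  • In-place ops such as torch_t.add_(...): NOT available in JAX, because arrays are immutable. Write t = t + s for a whole-array update, or t = t.at[idx].add(v) for an indexed one; both return a new array.
  • torch.argmax vs jnp.argmax: same name, but PyTorch’s dim= is axis= in JAX (likewise, keepdim= becomes keepdims= in reductions such as sum and mean).
  • torch.cat vs jnp.concatenate: the name differs, and dim= again becomes axis=.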
  • Random ops have different APIs entirely: PyTorch uses global state; JAX uses explicit keys (covered earlier in this track).
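The non-random pitfalls above can be checked directly in JAX. This is a small sketch. The dtype results assume JAX’s default config (x64 disabled), and the PyTorch behaviour is stated only in comments.

```python
import jax.numpy as jnp

# Dtype defaults under JAX's default config (x64 disabled).
print(jnp.array([1.0, 2.0]).dtype)   # float32 (torch.tensor also gives float32)
print(jnp.array([1, 2]).dtype)       # int32   (torch.tensor gives int64)

# Immutability: t[1] += 5.0 raises an error on a JAX array.
t = jnp.zeros(4)
t2 = t.at[1].add(5.0)                # indexed update returns a NEW array
print(t.tolist(), t2.tolist())       # [0.0, 0.0, 0.0, 0.0] [0.0, 5.0, 0.0, 0.0]
t = t + 1.0                          # whole-array analogue of t.add_(1.0): rebind the name

# dim= / keepdim= in PyTorch become axis= / keepdims= in jnp.
m = jnp.array([[1.0, 5.0, 2.0], [7.0, 0.0, 3.0]])
print(jnp.argmax(m, axis=1).tolist())           # [1, 0]  (torch.argmax(m, dim=1))
print(jnp.sum(m, axis=0, keepdims=True).shape)  # (1, 3)  (m.sum(dim=0, keepdim=True))

# torch.cat([m, m], dim=0) -> jnp.concatenate([m, m], axis=0)
print(jnp.concatenate([m, m], axis=0).shape)    # (4, 3)
```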

Problem

Implement torch_to_jax_translation(x, w), the JAX equivalent of the following PyTorch snippet:

y = torch.matmul(x, w)
y = torch.where(y > 0, y, torch.zeros_like(y))   # ReLU
y = torch.transpose(y, 0, 1)
return y.reshape(-1)

x has shape (N, d_in) and w has shape (d_in, d_out). The function returns a 1-D array of shape (N * d_out,).

Two illustrative examples (not from the test set):

  • x = [[1., 0.], [0., 1.]], w = [[2., -1.], [3., 4.]]: matmul = [[2,-1],[3,4]]. ReLU = [[2,0],[3,4]]. Transpose = [[2,3],[0,4]]. Flatten → [2., 3., 0., 4.]

  • x = [[0., -1., 2.]], w = [[1.], [0.], [1.]]: matmul = [[2.]]. ReLU = [[2.]]. Transpose = [[2.]]. Flatten → [2.]
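One possible line-by-line translation is sketched below. It is a sketch, checked only against the two illustrative examples above and not against the hidden test set. It mirrors the PyTorch snippet literally. jax.nn.relu(y) or y.T would be equally valid spellings for the middle steps.

```python
import jax
import jax.numpy as jnp

def torch_to_jax_translation(x, w):
    """A sketch of one possible JAX translation of the PyTorch snippet."""
    y = jnp.matmul(x, w)                          # torch.matmul(x, w); x @ w also works
    y = jnp.where(y > 0, y, jnp.zeros_like(y))    # ReLU; jax.nn.relu(y) is equivalent
    y = jnp.transpose(y, (1, 0))                  # torch.transpose(y, 0, 1) for 2-D; y.T also works
    return y.reshape(-1)                          # flatten to shape (N * d_out,)

# The two illustrative examples from the problem statement.
x1 = jnp.array([[1.0, 0.0], [0.0, 1.0]])
w1 = jnp.array([[2.0, -1.0], [3.0, 4.0]])
print(torch_to_jax_translation(x1, w1).tolist())   # [2.0, 3.0, 0.0, 4.0]

x2 = jnp.array([[0.0, -1.0, 2.0]])
w2 = jnp.array([[1.0], [0.0], [1.0]])
print(torch_to_jax_translation(x2, w2).tolist())   # [2.0]

# Only pure array ops are used, so the function also compiles under jax.jit.
print(jax.jit(torch_to_jax_translation)(x1, w1).tolist())  # [2.0, 3.0, 0.0, 4.0]
```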

