JAX numpy vs PyTorch ops
Why this matters
JAX’s jnp namespace deliberately mirrors NumPy's API, so if you can use
numpy, you mostly know jnp. PyTorch’s torch API covers many of the same
operations. Some keep the same name (torch.matmul ≈ jnp.matmul,
torch.zeros_like ≈ jnp.zeros_like); others differ in name or signature
(torch.transpose(t, dim0, dim1) vs jnp.transpose(arr, axes)). Translating
between PyTorch and JAX is therefore mostly a search-and-replace exercise,
once you know the few subtle differences.
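A minimal sketch of the NumPy-compatibility claim: the same function body runs against either module because jnp exposes the NumPy names used here. The helper normalize_rows and its xp parameter are hypothetical, for illustration only.

```python
import numpy as np
import jax.numpy as jnp


def normalize_rows(xp, a):
    # Hypothetical helper: `xp` is either the numpy or the jax.numpy module.
    # sqrt, sum(axis=..., keepdims=...) and broadcasting division exist in both.
    norms = xp.sqrt(xp.sum(a * a, axis=1, keepdims=True))
    return a / norms


data = [[3.0, 4.0], [6.0, 8.0]]
out_np = normalize_rows(np, np.array(data))    # NumPy computes in float64
out_jnp = normalize_rows(jnp, jnp.array(data))  # JAX computes in float32 by default

# Same numbers from both backends, up to float32 rounding.
print(np.allclose(out_np, np.asarray(out_jnp)))  # True
```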
| PyTorch | JAX |
|---|---|
| torch.matmul(a, b) | jnp.matmul(a, b) (or a @ b) |
| torch.where(cond, t, f) | jnp.where(cond, t, f) |
| torch.zeros_like(t) | jnp.zeros_like(t) |
| torch.transpose(t, 0, 1) | jnp.transpose(t, (1, 0)) for 2-D (full perm tuple, not two axes); jnp.swapaxes(t, 0, 1) for any rank |
| t.reshape(-1) | t.reshape(-1) (works in both) |
| torch.relu(t) | jax.nn.relu(t) (or jnp.maximum(t, 0)) |
| t.permute(2, 0, 1) | jnp.transpose(t, (2, 0, 1)) |
| t.unsqueeze(0) | t[None] or jnp.expand_dims(t, 0) |
| t.squeeze(0) | jnp.squeeze(t, 0) |
| t.detach() | jax.lax.stop_gradient(t) |
| with torch.no_grad(): | no context manager needed: JAX only differentiates inside jax.grad and similar transforms; use jax.lax.stop_gradient selectively there |
| t.cuda() | jax.device_put(t, jax.devices('gpu')[0]) |
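The rows above can be exercised in a short sketch; the array names and shapes are arbitrary choices for illustration. The stop_gradient part shows why it stands in for detach(): jax.grad treats the wrapped value as a constant. For device placement the sketch uses jax.devices()[0] rather than 'gpu', so it also runs on a CPU-only install.

```python
import jax
import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)  # [[0., 1., 2.], [3., 4., 5.]]
b = jnp.ones((3, 2))

# torch.matmul(a, b)  ->  jnp.matmul(a, b), or the @ operator
prod = a @ b

# torch.relu(t)  ->  jax.nn.relu(t); torch.where / zeros_like keep their names
z = jnp.array([-1.0, 0.0, 2.0])
relu = jax.nn.relu(z)
manual_relu = jnp.where(z > 0, z, jnp.zeros_like(z))

# t.permute(2, 0, 1)  ->  jnp.transpose(t, (2, 0, 1))
t = jnp.zeros((2, 3, 4))
permuted = jnp.transpose(t, (2, 0, 1))  # shape (4, 2, 3)

# t.unsqueeze(0) / t.squeeze(0)  ->  t[None] / jnp.squeeze(t, 0)
expanded = t[None]                   # shape (1, 2, 3, 4)
squeezed = jnp.squeeze(expanded, 0)  # shape (2, 3, 4)


# t.detach()  ->  jax.lax.stop_gradient(t): grad sees the wrapped factor as a constant
def f(x):
    return x * jax.lax.stop_gradient(x)


grad_at_3 = jax.grad(f)(3.0)  # d/dx [x * c] = c, so 3.0 here

# t.cuda()  ->  jax.device_put(...); jax.devices() lists the available backend's devices
on_device = jax.device_put(t, jax.devices()[0])
```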
Worked mini-example
The biggest API gotcha: torch.transpose(t, 0, 1) swaps two axes;
jnp.transpose takes a FULL axis-permutation tuple. For 2-D specifically,
jnp.transpose(t, (1, 0)) swaps the two axes; t.T is equivalent and more
idiomatic.
# PyTorch
import torch
y = torch.tensor([[1., 2.], [3., 4.]])
torch.transpose(y, 0, 1) # [[1,3],[2,4]]
# JAX
import jax.numpy as jnp
y = jnp.array([[1., 2.], [3., 4.]])
jnp.transpose(y, (1, 0)) # [[1,3],[2,4]] ← full perm tuple
y.T # same result, more idiomatic
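The 2-D case hides how the two signatures diverge on higher-rank arrays, where the tuple (1, 0) no longer covers every axis. A sketch, assuming a 3-D array of arbitrary shape (2, 3, 4):

```python
import jax.numpy as jnp

t = jnp.arange(24).reshape(2, 3, 4)

# torch.transpose(t, 0, 1) on a 3-D tensor swaps only axes 0 and 1.
# jnp.transpose needs the full permutation, here (1, 0, 2).
swapped = jnp.transpose(t, (1, 0, 2))
print(swapped.shape)  # (3, 2, 4)

# jnp.swapaxes takes the same two-axis form as torch.transpose.
print(jnp.array_equal(swapped, jnp.swapaxes(t, 0, 1)))  # True

# On N-D arrays, .T reverses ALL axes; it is not a two-axis swap.
print(t.T.shape)  # (4, 3, 2)
```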
Common pitfalls
- torch.transpose vs jnp.transpose argument shape: torch takes two axis ints; jnp takes a tuple-of-all-axes permutation.
- Float dtype defaults: both PyTorch and JAX default to float32 (for JAX, as long as x64 mode is disabled, which is the default config).
- In-place torch_t.add_(...): NOT available in JAX, because arrays are immutable. Use t.at[...].add(...) instead, which returns a new array.
- torch.argmax vs jnp.argmax: same name, but PyTorch's dim= is axis= in JAX.
- torch.cat vs jnp.concatenate: the name differs (and dim= again becomes axis=).
- Random ops have entirely different APIs: PyTorch draws from a global generator by default; JAX uses explicit keys (covered earlier in this track).
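A sketch of the pitfalls above, using arbitrary small arrays; the seed 0 is an arbitrary choice, and the dtype line assumes the default config (x64 disabled).

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(4)

# No in-place x.add_(...): .at[...].add(...) returns a NEW array, x is untouched.
y = x.at[1].add(5.0)
print(x)  # [0. 0. 0. 0.]
print(y)  # [0. 5. 0. 0.]

# torch.argmax(m, dim=1)  ->  jnp.argmax(m, axis=1)
m = jnp.array([[1, 9, 3], [7, 2, 8]])
best = jnp.argmax(m, axis=1)
print(best)  # [1 2]

# torch.cat([m, m], dim=0)  ->  jnp.concatenate([m, m], axis=0)
c = jnp.concatenate([m, m], axis=0)
print(c.shape)  # (4, 3)

# Float literals become float32 under the default config.
print(jnp.array([1.0]).dtype)  # float32

# Randomness: split an explicit key instead of relying on a global generator.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, (3,))
```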
Problem
Implement torch_to_jax_translation(x, w), the JAX equivalent of the
following PyTorch snippet:
import torch

def pytorch_reference(x, w):  # hypothetical name for the original snippet
    y = torch.matmul(x, w)
    y = torch.where(y > 0, y, torch.zeros_like(y))  # ReLU
    y = torch.transpose(y, 0, 1)
    return y.reshape(-1)
x has shape (N, d_in) and w has shape (d_in, d_out). The function
returns a 1-D array of shape (N * d_out,).
Two illustrative examples (not from the test set):
- x = [[1., 0.], [0., 1.]], w = [[2., -1.], [3., 4.]]: matmul = [[2, -1], [3, 4]]. ReLU = [[2, 0], [3, 4]]. Transpose = [[2, 3], [0, 4]]. Flatten → [2., 3., 0., 4.]
- x = [[0., -1., 2.]], w = [[1.], [0.], [1.]]: matmul = [[2.]]. ReLU = [[2.]]. Transpose = [[2.]]. Flatten → [2.]
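One possible translation, written as a sketch line by line against the PyTorch snippet (it is not the course's reference solution); it reproduces both illustrative examples.

```python
import jax.numpy as jnp


def torch_to_jax_translation(x, w):
    # torch.matmul(x, w)  ->  x @ w
    y = x @ w
    # torch.where / torch.zeros_like  ->  the same calls under jnp (ReLU)
    y = jnp.where(y > 0, y, jnp.zeros_like(y))
    # torch.transpose(y, 0, 1) on a 2-D array  ->  full permutation (1, 0), i.e. y.T
    y = jnp.transpose(y, (1, 0))
    # reshape(-1) is spelled the same in both libraries
    return y.reshape(-1)


x1 = jnp.array([[1., 0.], [0., 1.]])
w1 = jnp.array([[2., -1.], [3., 4.]])
print(torch_to_jax_translation(x1, w1))  # [2. 3. 0. 4.]

x2 = jnp.array([[0., -1., 2.]])
w2 = jnp.array([[1.], [0.], [1.]])
print(torch_to_jax_translation(x2, w2))  # [2.]
```

Because every step is a pure jnp call, the same function can be wrapped in jax.jit without changes.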