Shape Polymorphism (Concept)
Why this matters
JAX source code is naturally shape-polymorphic: write jnp.sum(x) once
and it handles any 1-D array regardless of length. This is Python-level
polymorphism: the function works for any shape at the source level.
However, jax.jit is shape-monomorphic at the compilation level: each
distinct (shape, dtype) combination triggers a fresh XLA compilation and
caches a separate compiled executable. Calling the same jitted function with
shape (3,) and then shape (5,) traces and compiles twice.
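One way to observe this specialization is to count how often JAX traces the function. The sketch below uses a hypothetical module-level counter for illustration: a Python side effect inside a jitted function runs only while JAX traces it, so the counter goes up once per new input shape rather than once per call.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, only for illustration

@jax.jit
def jitted_sum(x):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    return jnp.sum(x)

jitted_sum(jnp.ones(3))  # first (3,) float32 call: trace + compile
jitted_sum(jnp.ones(3))  # same shape and dtype: cache hit, no retrace
jitted_sum(jnp.ones(5))  # new shape (5,): trace + compile again
print(trace_count)  # 2
```

Each retrace produces a new shape-specialized program, which is why the counter tracks the number of specializations.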
For deployment (shipping a trained model whose batch size or sequence
length is unknown at export time), you need a single artifact that
handles many shapes. jax.export supports this via symbolic dimensions:
import jax
import jax.numpy as jnp
from jax import export

def f(x):
    return jnp.sum(x)

# Export with a symbolic batch dimension 'b'.
# symbolic_shape returns a tuple of dimension expressions, so unpack it.
(b,) = export.symbolic_shape("b")
exported = export.export(jax.jit(f))(
    jax.ShapeDtypeStruct((b,), jnp.float32)
)
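The exported artifact then accepts any concrete value of b, so you export
once instead of once per shape. (The runtime may still compile a
shape-specialized executable for each new concrete b it sees.)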
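To see one artifact serving several shapes end to end, the sketch below exports f with a symbolic dimension, round-trips it through serialization as a deployment pipeline might, and calls the restored artifact on two different lengths. It assumes a JAX release that ships the jax.export module with Exported.serialize, export.deserialize, and Exported.call; the variable names are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import export

def f(x):
    return jnp.sum(x)

# One symbolic dimension 'b'; symbolic_shape returns the tuple (b,).
(b,) = export.symbolic_shape("b")
exported = export.export(jax.jit(f))(jax.ShapeDtypeStruct((b,), jnp.float32))
print(exported.in_avals)  # the input aval carries the symbolic dimension b

# Serialize to bytes and restore, as if shipping to another process.
blob = exported.serialize()
restored = export.deserialize(blob)

# The same restored artifact handles different concrete lengths.
print(restored.call(jnp.ones(3, dtype=jnp.float32)))    # 3.0
print(restored.call(jnp.arange(5, dtype=jnp.float32)))  # 10.0
```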
Worked mini-example
import jax.numpy as jnp
def polymorphic_sum(x):
    return jnp.sum(x)

# Works for any 1-D shape (source-level polymorphism):
print(polymorphic_sum(jnp.array([1.0, 2.0, 3.0]))) # 6.0
print(polymorphic_sum(jnp.array([10.0, 20.0, 30.0, 40.0, 50.0]))) # 150.0
Common pitfalls
- Assuming jit handles dynamic shapes: it does not; it specializes on each distinct shape. Pass many different shapes in a hot loop and you will pay a compilation cost for each one.
- Source polymorphism is not JIT polymorphism: the source code works for all shapes, but each JIT-compiled specialization is shape-monomorphic.
- jax.export for deployment: use it when you need a single exported artifact for unknown shapes; it requires annotating symbolic dimensions.
Problem
Implement polymorphic_sum(x) that returns jnp.sum(x). This is
intentionally simple: the lesson is the conceptual distinction between
source-level shape polymorphism and JIT's shape-monomorphic compilation.
- x: 1-D JAX array of any length.
- Returns: a scalar, jnp.sum(x).
Examples (not from the test set):
- polymorphic_sum(jnp.array([1.0, 2.0, 3.0])) → 6.0