
Shape Polymorphism (Concept)

Why this matters

JAX source code is naturally shape-polymorphic: write jnp.sum(x) once and it handles any 1-D array regardless of length. This is Python-level polymorphism: the function works for any shape at the source level.

However, jax.jit is shape-monomorphic at the compilation level: the first time it sees a new (shape, dtype) combination, it retraces the function, runs a fresh XLA compilation, and caches a separate executable. Calling the same jitted function with shape (3,) and then shape (5,) therefore compiles twice.
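One way to observe this specialization is to count how often jit traces the function. The sketch below assumes that a Python side effect inside a jitted function runs only during tracing (which is how jit behaves), and uses that trace count as a proxy for compilations; the name jitted_sum is illustrative.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def jitted_sum(x):
    global trace_count
    # Python side effect: executes only while jit traces a new
    # (shape, dtype) signature, not on cached calls.
    trace_count += 1
    return jnp.sum(x)

jitted_sum(jnp.ones(3))                   # float32[3]: traced and compiled
jitted_sum(jnp.ones(5))                   # float32[5]: new shape, traced again
jitted_sum(jnp.ones(3))                   # float32[3] seen before: cache hit
jitted_sum(jnp.ones(3, dtype=jnp.int32))  # int32[3]: same shape, new dtype, traced again
print(trace_count)  # 3
```

The repeated float32[3] call reuses the cached executable, while both a new length and a new dtype force a retrace.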

For deployment (shipping a trained model whose batch size or sequence length is unknown at export time), you need a single exported artifact that handles many shapes. jax.export supports this through symbolic dimensions:

import jax
import jax.numpy as jnp
from jax import export

def f(x):
    return jnp.sum(x)

# Export with a symbolic batch dimension 'b'.
# symbolic_shape returns a tuple of dimensions, so unpack its single entry.
(b,) = export.symbolic_shape("b")
exported = export.export(jax.jit(f))(
    jax.ShapeDtypeStruct((b,), jnp.float32)
)

The exported artifact accepts any concrete value of b, so you export (and serialize) once instead of once per shape.
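A minimal sketch of the round trip, assuming a recent JAX where jax.export provides Exported.serialize, export.deserialize, and Exported.call: the same exported object is serialized once, restored, and then called on inputs of two different lengths. The variable names are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import export

def f(x):
    return jnp.sum(x)

# One export with a symbolic leading dimension 'b'.
(b,) = export.symbolic_shape("b")
exported = export.export(jax.jit(f))(jax.ShapeDtypeStruct((b,), jnp.float32))

# Serialize once (e.g. to ship alongside a model), then restore it.
blob = exported.serialize()
restored = export.deserialize(blob)

# The restored artifact accepts any concrete length for 'b'.
r3 = restored.call(jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32))
r5 = restored.call(jnp.arange(5, dtype=jnp.float32))  # 0 + 1 + 2 + 3 + 4
print(r3, r5)  # 6.0 10.0
```

Inputs must still match the exported dtype (float32 here); only the dimension marked as symbolic is free to vary.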

Worked mini-example

import jax.numpy as jnp

def polymorphic_sum(x):
    return jnp.sum(x)

# Works for any 1-D shape (source-level polymorphism):
print(polymorphic_sum(jnp.array([1.0, 2.0, 3.0])))        # 6.0
print(polymorphic_sum(jnp.array([10.0, 20.0, 30.0, 40.0, 50.0])))  # 150.0

Common pitfalls

  • Assuming jit handles dynamic shapes: it does not; it specializes on each distinct shape. Pass many different shapes in a hot loop and you will pay many compilation costs.
  • Source polymorphism ≠ JIT polymorphism: the source code works for all shapes, but each JIT-compiled specialization is shape-monomorphic.
  • jax.export for deployment: use it when you need a single compiled artifact for unknown shapes; it requires annotating symbolic dimensions.
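For the hot-loop pitfall, one common workaround (not prescribed by the text above) is to pad inputs up to a small set of bucket sizes so jit only ever sees a few shapes. This sketch assumes zero padding is neutral for the computation, which holds for a sum but not in general; next_pow2 and bucketed_sum are hypothetical helpers, and the trace counter is used as a proxy for compilations.

```python
import jax
import jax.numpy as jnp
import numpy as np

trace_count = 0

@jax.jit
def jitted_sum(x):
    global trace_count
    trace_count += 1  # runs only when jit traces a new input shape
    return jnp.sum(x)

def next_pow2(n):
    # Smallest power of two >= n, for n >= 1.
    return 1 << (n - 1).bit_length()

def bucketed_sum(x):
    # Hypothetical helper: zero-pad on the host with NumPy so the jitted
    # call only sees power-of-two lengths. Zeros do not change the sum.
    n = x.shape[0]
    padded = np.zeros(next_pow2(n), dtype=np.float32)
    padded[:n] = x
    return jitted_sum(padded)

results = [float(bucketed_sum(np.ones(n, dtype=np.float32))) for n in range(1, 17)]
print(results[-1])  # 16.0
print(trace_count)  # 5 (lengths padded to 1, 2, 4, 8, 16)
```

Sixteen distinct input lengths cost five compilations instead of sixteen, at the price of some wasted work on padding.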

Problem

Implement polymorphic_sum(x) that returns jnp.sum(x). This is intentionally simple: the lesson is the conceptual distinction between source-level shape polymorphism and JIT’s shape-monomorphic compilation.

  • x: 1-D JAX array of any length.

Returns: scalar β€” jnp.sum(x).

Examples (not from the test set):

  • polymorphic_sum(jnp.array([1.0, 2.0, 3.0])) → 6.0

Hints

jax shape-polymorphism jax.export
