medium primitives

Linear Interpolation

Why this matters

jnp.interp(x_query, x_known, y_known) performs piecewise linear interpolation. Given a set of known (x_known, y_known) points, it estimates y at arbitrary query positions by linearly connecting adjacent known points.
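The mechanism can be written out by hand: find the segment that contains each query with jnp.searchsorted, then blend that segment's two end values by the query's fractional position. The helper name manual_interp is hypothetical. The sketch shows the technique and is not a claim about how jnp.interp is implemented internally. It assumes x_known is strictly increasing and has at least two points.

```python
import jax.numpy as jnp


def manual_interp(x_query, x_known, y_known):
    """Hypothetical helper: piecewise linear interpolation written out by hand.

    Assumes x_known is strictly increasing with at least two points.
    Out-of-range queries are clamped to the end values, matching the
    default behaviour of jnp.interp.
    """
    # Index i of the right-hand knot for each query, so that
    # x_known[i - 1] <= x_query <= x_known[i]; clip keeps i in [1, n - 1].
    i = jnp.clip(jnp.searchsorted(x_known, x_query), 1, x_known.shape[0] - 1)
    x0, x1 = x_known[i - 1], x_known[i]
    y0, y1 = y_known[i - 1], y_known[i]
    # Fractional position of each query inside its segment.
    t = (x_query - x0) / (x1 - x0)
    y = y0 + t * (y1 - y0)
    # Clamp queries outside [x_known[0], x_known[-1]] to the boundary values.
    y = jnp.where(x_query < x_known[0], y_known[0], y)
    y = jnp.where(x_query > x_known[-1], y_known[-1], y)
    return y


x_known = jnp.array([0.0, 1.0, 2.0])
y_known = jnp.array([0.0, 2.0, 0.0])
x_query = jnp.array([-1.0, 0.5, 1.5, 3.0])
print(manual_interp(x_query, x_known, y_known))  # values 0.0, 1.0, 1.0, 0.0
print(jnp.interp(x_query, x_known, y_known))     # same values
```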

This is used in:

  • Data resampling: upsample or downsample a time series onto a different grid.
  • Function approximation: represent a smooth function with a lookup table of samples.
  • Simulation of analog signals: reconstruct a continuous signal from discrete samples.
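A minimal resampling sketch, assuming a toy series of t**2 sampled at integer times (the data is made up for illustration). The coarse samples are mapped onto a grid twice as dense, and onto a sparser grid.

```python
import jax.numpy as jnp

# Toy series: samples of t**2 at t = 0, 1, 2, 3, 4 (illustrative data).
t_coarse = jnp.arange(5.0)
signal = t_coarse ** 2

# Upsample onto a grid twice as dense: t = 0.0, 0.5, ..., 4.0.
t_fine = jnp.linspace(0.0, 4.0, 9)
upsampled = jnp.interp(t_fine, t_coarse, signal)

# Downsample onto every other original time: t = 0, 2, 4.
t_sparse = jnp.array([0.0, 2.0, 4.0])
downsampled = jnp.interp(t_sparse, t_coarse, signal)

print(upsampled)    # new points sit on the straight line between neighbours
print(downsampled)  # values 0.0, 4.0, 16.0
```

At t = 1.5 the upsampled value is 2.5, while t**2 is 2.25. Linear interpolation follows the chord between neighbouring samples, so it cannot recover curvature that falls between them.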

Key facts:

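  • x_known must be sorted in ascending order. The requirement is not checked: the wrong order produces garbage without an error.
  • Out-of-range x_query values are clamped to the boundary y values by default, so there is no extrapolation. This matches numpy.interp. scipy's interp1d instead raises an error by default, or extrapolates when given fill_value='extrapolate'.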
  • Output shape matches x_query.
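Both behaviours can be checked directly. This sketch uses made-up sample values. It shows default clamping at both ends, and an output that keeps the 2-D shape of the query array while x_known and y_known stay 1-D.

```python
import jax.numpy as jnp

x_known = jnp.array([0.0, 1.0, 2.0])
y_known = jnp.array([10.0, 20.0, 30.0])

# Queries below x_known[0] or above x_known[-1] are clamped by default.
out_of_range = jnp.interp(jnp.array([-5.0, 7.0]), x_known, y_known)
print(out_of_range)  # values 10.0 and 30.0

# The output takes the shape of x_query, here a 2x2 grid of queries.
grid = jnp.array([[0.25, 0.5], [1.5, 1.75]])
out_grid = jnp.interp(grid, x_known, y_known)
print(out_grid.shape)  # (2, 2)
```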

Worked mini-example

import jax.numpy as jnp

x_known = jnp.array([0.0, 1.0, 2.0])
y_known = jnp.array([0.0, 2.0, 0.0])
x_query = jnp.array([0.5, 1.0, 1.5])
out = jnp.interp(x_query, x_known, y_known)
# out = [1.0, 2.0, 1.0]

Common pitfalls

  • Unsorted x_known: jnp.interp silently returns wrong results if x_known is not sorted ascending. Sort before calling, and reorder y_known with the same permutation (for example via jnp.argsort) so that each y value stays paired with its x value.
  • Clamping, not extrapolation: query points outside [x_known[0], x_known[-1]] return the boundary y value, not an extrapolated estimate. This differs from scipy's interp1d with fill_value='extrapolate'.
  • Keyword arguments vary by version: recent jax.numpy.interp releases mirror numpy.interp's left, right, and period keywords, but older JAX versions may not support all of them. Check the JAX docs for the version you use.
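Both pitfalls can be worked around in a few lines. This sketch uses illustrative data (y = 2x) supplied in descending x order. It sorts the pairs together with jnp.argsort, then compares default clamping with linear extrapolation. interp_extrapolate is a hypothetical helper written for this example. It is similar in spirit to scipy's fill_value='extrapolate' but is not a JAX API.

```python
import jax.numpy as jnp

# Known points supplied in descending x order (illustrative data, y = 2x).
x_known = jnp.array([2.0, 1.0, 0.0])
y_known = jnp.array([4.0, 2.0, 0.0])

# Fix for unsorted input: sort x and reorder y with the same permutation.
order = jnp.argsort(x_known)
x_sorted = x_known[order]
y_sorted = y_known[order]
inside = jnp.interp(jnp.array([0.5, 1.5]), x_sorted, y_sorted)  # values 1.0, 3.0


def interp_extrapolate(x_query, x_known, y_known):
    """Hypothetical helper: jnp.interp inside the range, linear extension of
    the first and last segments outside it. Assumes x_known is sorted
    ascending and has at least two distinct points."""
    y = jnp.interp(x_query, x_known, y_known)
    left_slope = (y_known[1] - y_known[0]) / (x_known[1] - x_known[0])
    right_slope = (y_known[-1] - y_known[-2]) / (x_known[-1] - x_known[-2])
    # Replace the clamped values with points on the extended end segments.
    y = jnp.where(x_query < x_known[0],
                  y_known[0] + left_slope * (x_query - x_known[0]), y)
    y = jnp.where(x_query > x_known[-1],
                  y_known[-1] + right_slope * (x_query - x_known[-1]), y)
    return y


queries = jnp.array([-1.0, 3.0])
clamped = jnp.interp(queries, x_sorted, y_sorted)               # values 0.0, 4.0
extrapolated = interp_extrapolate(queries, x_sorted, y_sorted)  # values -2.0, 6.0
print(inside, clamped, extrapolated)
```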

Problem

Implement interp_at(x_query, x_known, y_known) that performs piecewise linear interpolation.

  • x_query: 1-D jax array, the points to evaluate at.
  • x_known: 1-D jax array, the known x coordinates (must be sorted ascending).
  • y_known: 1-D jax array, the known y values, same length as x_known.
  • Returns: 1-D array with the same shape as x_query, holding the interpolated y values.
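One possible sketch of interp_at, assuming it is acceptable to delegate to jnp.interp after basic shape checks. The error messages and the choice to raise ValueError are illustrative, not part of the problem statement.

```python
import jax.numpy as jnp


def interp_at(x_query, x_known, y_known):
    """Sketch: piecewise linear interpolation via jnp.interp with shape checks."""
    x_query = jnp.asarray(x_query)
    x_known = jnp.asarray(x_known)
    y_known = jnp.asarray(y_known)
    # These checks read static shapes, so they also work under jax.jit.
    if x_known.ndim != 1 or y_known.ndim != 1:
        raise ValueError("x_known and y_known must be 1-D")
    if x_known.shape != y_known.shape:
        raise ValueError("x_known and y_known must have the same length")
    # Sortedness of x_known is the caller's responsibility, as with jnp.interp.
    return jnp.interp(x_query, x_known, y_known)


x_known = jnp.array([0.0, 1.0, 2.0])
y_known = jnp.array([0.0, 2.0, 0.0])
result = interp_at(jnp.array([0.5, 1.0, 1.5]), x_known, y_known)
print(result)  # values 1.0, 2.0, 1.0
```

The sketch does not verify that x_known is sorted. Under jax.jit such a check would produce a traced value that cannot drive a Python if statement.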

Hints

jax interp
