easy primitives

Polynomial Evaluation via polyval

Why this matters

jnp.polyval(coeffs, x_pts) evaluates a polynomial at an array of points. Coefficients are given in descending power order, the same convention jnp.polyfit uses for its output. For coeffs of length n, the result is coeffs[0] * x^(n-1) + coeffs[1] * x^(n-2) + ... + coeffs[-1].
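To make the descending convention concrete, here is a small check that builds the sum term by term, with n = len(coeffs), and compares it with jnp.polyval. The coefficient and point values are made up for illustration.

```python
import jax.numpy as jnp

# Descending order: coeffs[0] multiplies the highest power x^(n-1).
coeffs = jnp.array([2.0, -3.0, 5.0])      # 2x^2 - 3x + 5
x_pts = jnp.array([-1.0, 0.0, 1.0, 2.0])

n = coeffs.shape[0]
# Explicit sum of coeffs[i] * x^(n-1-i) for i = 0 .. n-1.
manual = sum(coeffs[i] * x_pts ** (n - 1 - i) for i in range(n))

out = jnp.polyval(coeffs, x_pts)
# out ≈ [10.0, 5.0, 4.0, 7.0], matching manual
```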

polyval is the natural complement to polyfit:

  • Prediction: evaluate a fitted polynomial at new x values.
  • Function approximation: compute a polynomial model at many points in one call.
  • Visualization: generate dense curve points for plotting.
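A minimal sketch of the fit-then-evaluate workflow, assuming noise-free samples of y = 2x^2 - 3x + 1 chosen only for illustration. Because jnp.polyfit returns coefficients highest power first, its output can be passed straight to jnp.polyval for prediction and for a dense plotting grid.

```python
import jax.numpy as jnp

# Illustrative training data: exact samples of y = 2x^2 - 3x + 1.
x_train = jnp.linspace(-2.0, 2.0, 9)
y_train = 2.0 * x_train**2 - 3.0 * x_train + 1.0

# Least-squares fit; coefficients come back in descending order.
coeffs = jnp.polyfit(x_train, y_train, deg=2)    # ≈ [2, -3, 1]

# Prediction at new points.
x_new = jnp.array([3.0, 4.0])
y_pred = jnp.polyval(coeffs, x_new)              # ≈ [10, 21]

# Dense grid of curve points, e.g. for plotting.
x_grid = jnp.linspace(-2.0, 2.0, 200)
y_grid = jnp.polyval(coeffs, x_grid)
```

The fitted values are approximate (least squares in float32), so compare them with a tolerance rather than exact equality.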

The function is vectorized over x_pts: it evaluates elementwise, so a 1-D array of points gives one output per point, and in general the output has the same shape as x_pts.
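One standard way to evaluate a descending-order polynomial elementwise is Horner's rule, which needs only one multiply and one add per coefficient. The helper horner below is hypothetical, a sketch of the technique rather than a claim about how jnp.polyval is implemented internally; it uses jax.lax.scan to fold over the coefficients and works on an x of any shape.

```python
import jax
import jax.numpy as jnp


def horner(coeffs, x):
    """Hypothetical helper: evaluate a descending-order polynomial by Horner's rule.

    Rewrites c0*x^(n-1) + c1*x^(n-2) + ... + c_{n-1} as
    (...((c0*x + c1)*x + c2)...)*x + c_{n-1}, applied elementwise over x.
    """
    x = jnp.asarray(x)
    # Accumulator starts at zero with x's shape and the promoted dtype.
    init = jnp.zeros_like(x, dtype=jnp.result_type(coeffs, x))

    def step(acc, c):
        # Multiply the running value by x, then add the next lower-power coefficient.
        return acc * x + c, None

    result, _ = jax.lax.scan(step, init, coeffs)
    return result


coeffs = jnp.array([1.0, -2.0, 0.5])   # x^2 - 2x + 0.5
x = jnp.arange(6.0).reshape(2, 3)      # a 2-D grid of points

print(horner(coeffs, x).shape)                                  # (2, 3)
print(jnp.allclose(horner(coeffs, x), jnp.polyval(coeffs, x)))  # True
```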

Worked mini-example

import jax.numpy as jnp

coeffs = jnp.array([1.0, 0.0, 0.0])   # x^2 (descending: [1, 0, 0])
x_pts  = jnp.array([0.0, 1.0, 2.0, 3.0])
out = jnp.polyval(coeffs, x_pts)
# out = [0.0, 1.0, 4.0, 9.0]
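A slightly richer case, sketched with made-up values: a cubic with no x term. Every power still needs a slot in coeffs, so the missing term appears as an explicit 0.0. The example also wraps the call in jax.jit to show that polyval composes with compilation like any other jnp function.

```python
import jax
import jax.numpy as jnp

coeffs = jnp.array([2.0, -3.0, 0.0, 5.0])   # 2x^3 - 3x^2 + 0x + 5
x_pts = jnp.array([-1.0, 0.0, 1.0, 2.0])

# Compile the evaluation once; later calls with same-shaped inputs reuse it.
eval_poly = jax.jit(jnp.polyval)
out = eval_poly(coeffs, x_pts)
# out = [0.0, 5.0, 4.0, 9.0]
```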

Common pitfalls

  • Ascending vs descending: jnp.polyval expects descending order. Coefficients stored in ascending order (e.g., the .coef attribute of numpy.polynomial.Polynomial) must be reversed first, for example with coeffs[::-1]; otherwise the evaluation is silently wrong.
  • Scalar vs array x: both work. A scalar x returns a 0-d (scalar) array; wrap the point in a 1-D array if you need an output of shape (1,).
  • No in-place mutation: JAX arrays are immutable, so polyval always returns a new array.
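The first two pitfalls can be reproduced in a few lines. This sketch uses NumPy's numpy.polynomial.Polynomial as a typical source of ascending-order coefficients; the specific polynomial and evaluation points are illustrative.

```python
import numpy as np
import jax.numpy as jnp

# numpy.polynomial.Polynomial stores coefficients in ASCENDING order:
# Polynomial([1, 2, 3]) represents 1 + 2x + 3x^2.
p = np.polynomial.Polynomial([1.0, 2.0, 3.0])
x = 2.0

# Passing .coef as-is makes polyval read it as x^2 + 2x + 3 (descending).
wrong = jnp.polyval(jnp.asarray(p.coef), x)         # 4 + 4 + 3 = 11
# Reversing gives the descending order polyval expects: 3x^2 + 2x + 1.
right = jnp.polyval(jnp.asarray(p.coef[::-1]), x)   # 12 + 4 + 1 = 17
reference = p(x)                                     # NumPy's own evaluation: 17.0

# Scalar vs array x: a scalar point gives a 0-d result,
# a length-1 array gives shape (1,).
sq = jnp.array([1.0, 0.0, 0.0])                      # x^2
scalar_out = jnp.polyval(sq, 3.0)                    # shape ()
array_out = jnp.polyval(sq, jnp.array([3.0]))        # shape (1,)
```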

Problem

Implement a function polyval_at(coeffs, x_pts) that evaluates a polynomial at each point in x_pts.

  • coeffs: 1-D JAX array of shape (deg+1,), coefficients in descending power order.
  • x_pts: 1-D JAX array of points to evaluate at.
  • Returns: 1-D array with the same shape as x_pts, holding the polynomial values.
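A minimal sketch of one possible approach, assuming jnp.polyval is an acceptable building block. Since polyval already works elementwise over x_pts, the wrapper needs no loop; the jax.jit decorator is an optional choice, not a requirement of the spec.

```python
import jax
import jax.numpy as jnp


@jax.jit
def polyval_at(coeffs: jax.Array, x_pts: jax.Array) -> jax.Array:
    """Evaluate the polynomial with descending-order coeffs at each point in x_pts."""
    # jnp.polyval is elementwise over x_pts, so the result has x_pts's shape.
    return jnp.polyval(coeffs, x_pts)


out = polyval_at(jnp.array([1.0, 0.0, 0.0]), jnp.array([0.0, 1.0, 2.0, 3.0]))
# out = [0.0, 1.0, 4.0, 9.0]
```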

Hints

See the polynomial functions in the JAX documentation, in particular jax.numpy.polyval.
