
Cumulative Trapezoidal Integration

Why this matters

The trapezoidal rule approximates a definite integral by summing trapezoid areas between adjacent sample points:

$$\int y\,dx \approx \sum_{i=0}^{N-2} \frac{y[i] + y[i+1]}{2}\,dx$$
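The sum above translates almost directly into JAX. The sketch below computes only the total (not the running total). It assumes uniform spacing dx, and the helper name trapezoid_total is hypothetical, chosen for illustration.

```python
import jax.numpy as jnp

def trapezoid_total(y, dx):
    """Total trapezoidal integral of uniformly spaced samples y (hypothetical helper)."""
    # y[:-1] are the left edges and y[1:] the right edges of each trapezoid;
    # each trapezoid has area (left + right) / 2 * dx, and we add them all up.
    return jnp.sum((y[:-1] + y[1:]) / 2 * dx)

y = jnp.array([0.0, 1.0, 2.0, 3.0])
total = trapezoid_total(y, 1.0)  # 0.5 + 1.5 + 2.5 = 4.5
```

The cumulative version keeps every partial sum instead of only the last one, so its final element equals this total.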

The cumulative version returns the running total after each step. This is useful for:

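  • ROC/AUC computation: the area under the ROC curve is a trapezoidal sum over the curve's points, and the running total gives the partial area up to each point (ROC points are unevenly spaced, so this needs the non-uniform variant).
  • Signal processing: running integral of a sampled waveform.
  • Physics simulation: cumulative displacement from velocity samples.
  • Time-series accumulation: total rainfall, energy consumed, etc.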
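As a concrete instance of the physics case, here is a small sketch with made-up data: velocity samples of v(t) = 2t taken at t = 0, 1, 2, 3 with dt = 1. The trapezoidal rule is exact for linear integrands, so the running displacement matches s(t) = t² exactly.

```python
import jax.numpy as jnp

dt = 1.0
# Illustrative velocity samples of v(t) = 2t at t = 0, 1, 2, 3.
v = jnp.array([0.0, 2.0, 4.0, 6.0])

# Area of each trapezoid between adjacent samples, then the running total.
displacement = jnp.cumsum((v[:-1] + v[1:]) / 2 * dt)
# displacement is [1.0, 4.0, 9.0], i.e. s(t) = t**2 at t = 1, 2, 3
```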

JAX does not ship a built-in cumtrapz (the counterpart of SciPy's scipy.integrate.cumulative_trapezoid). You build it from two primitives: slice arithmetic (y[:-1], y[1:]) and jnp.cumsum.
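To see why those two primitives are enough, it can help to compare them against an explicit loop that walks over adjacent pairs and keeps a running total. Both function names below are hypothetical. The loop version is only a slow reference for checking results and is not meant for use under jax.jit.

```python
import jax.numpy as jnp

def cumtrapz_loop(y, dx):
    """Reference: explicit loop over adjacent pairs (slow, for checking only)."""
    out = []
    running = 0.0
    for i in range(len(y) - 1):
        # Add the area of the trapezoid between samples i and i + 1.
        running += (float(y[i]) + float(y[i + 1])) / 2 * dx
        out.append(running)
    return out

def cumtrapz_vectorized(y, dx):
    """Same computation with slice arithmetic plus jnp.cumsum."""
    return jnp.cumsum((y[:-1] + y[1:]) / 2 * dx)

y = jnp.array([0.0, 1.0, 4.0, 9.0, 16.0])
loop_result = cumtrapz_loop(y, 0.5)       # [0.25, 1.5, 4.75, 11.0]
vec_result = cumtrapz_vectorized(y, 0.5)  # same values as a JAX array
```

The slices replace the index pair (i, i + 1) with two shifted views of the whole array, and cumsum replaces the running accumulator.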

Worked mini-example

import jax.numpy as jnp

y  = jnp.array([0.0, 1.0, 2.0, 3.0])
dx = 1.0

trap_areas = (y[:-1] + y[1:]) / 2 * dx
# = [(0+1)/2, (1+2)/2, (2+3)/2] * 1.0
# = [0.5, 1.5, 2.5]

result = jnp.cumsum(trap_areas)
# = [0.5, 2.0, 4.5]

Output length is N - 1 (one fewer than the input) because N samples form N - 1 trapezoids, one per pair of adjacent samples. Element k of the result is the integral from the first sample up to sample k + 1.
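Sometimes you want one value per sample, for example to plot the running integral against the original grid. One common convention, which SciPy's cumulative_trapezoid offers through initial=0, is to prepend a 0 for the empty integral at the first sample. The sketch below shows that padding. Note that it is an optional variant: the problem on this page asks for the N - 1 version.

```python
import jax.numpy as jnp

y = jnp.array([0.0, 1.0, 2.0, 3.0])
dx = 1.0

running = jnp.cumsum((y[:-1] + y[1:]) / 2 * dx)                    # shape (N - 1,) = (3,)
aligned = jnp.concatenate([jnp.zeros(1, running.dtype), running])  # shape (N,) = (4,)
# aligned[k] is the integral from the first sample up to sample k: [0.0, 0.5, 2.0, 4.5]
```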

Common pitfalls

  • Output length N instead of N - 1: returning the cumsum of something of length N is wrong, because each trapezoid needs two points.
  • Forgetting the / 2 โ€” (y[i] + y[i+1]) * dx instead of (y[i] + y[i+1]) / 2 * dx doubles every area.
  • Non-uniform spacing: this implementation assumes uniform dx. For non-uniform x coordinates, use jnp.trapezoid(y, x) (total only, not cumulative) or compute variable-width trapezoids manually.
  • Off-by-one on slicing โ€” y[:-1] gives all but the last; y[1:] gives all but the first. These two slices are the left and right edges of each trapezoid.
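For the non-uniform case, one way to compute variable-width trapezoids manually is to replace the constant dx with the per-interval widths jnp.diff(x). The sketch below uses a hypothetical function name and assumes x is sorted and has the same length as y.

```python
import jax.numpy as jnp

def cumulative_integral_nonuniform(y, x):
    """Cumulative trapezoid for samples y at (possibly uneven) coordinates x (hypothetical)."""
    widths = jnp.diff(x)                   # width of each trapezoid, shape (N - 1,)
    areas = (y[:-1] + y[1:]) / 2 * widths  # per-interval trapezoid areas
    return jnp.cumsum(areas)               # running total, shape (N - 1,)

x = jnp.array([0.0, 1.0, 3.0])
y = x                                      # integrate y = x on an uneven grid
result = cumulative_integral_nonuniform(y, x)
# result is [0.5, 4.5], matching x**2 / 2 at x = 1 and x = 3
```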

Problem

Implement cumulative_integral(y, dx), the cumulative trapezoidal integral.

  • y: 1-D jax array (N values).
  • dx: scalar, the uniform spacing between samples.
  • Returns: 1-D array of shape (N - 1,), the running cumulative integral after each step.
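One possible implementation, following the recipe from the worked example, is sketched below. Wrapping it in jax.jit is an optional choice, not part of the specification. Because it uses only differentiable array operations, jax.grad also works through it. The gradient of the total with respect to y is the trapezoid weight vector: half weight at the endpoints and full weight in the interior.

```python
import jax
import jax.numpy as jnp

@jax.jit
def cumulative_integral(y, dx):
    # Left and right edges of each trapezoid.
    left, right = y[:-1], y[1:]
    # Area of each trapezoid with uniform width dx.
    areas = (left + right) / 2 * dx
    # Running total; the result has shape (N - 1,).
    return jnp.cumsum(areas)

y = jnp.array([0.0, 1.0, 2.0, 3.0])
out = cumulative_integral(y, 1.0)  # [0.5, 2.0, 4.5]

# Gradient of the total integral (last element) with respect to the samples.
grads = jax.grad(lambda v: cumulative_integral(v, 1.0)[-1])(y)  # [0.5, 1.0, 1.0, 0.5]
```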

Hints

jax trapz integration
