hard
primitives
Cumulative Trapezoidal Integration
Why this matters
The trapezoidal rule approximates a definite integral by summing trapezoid areas between adjacent sample points:
∫ y dx ≈ Σ (y[i] + y[i+1]) / 2 * dx
The cumulative version returns the running total after each step, which is useful for:
- ROC/AUC computation: the area under the ROC curve is the final value of a running trapezoidal sum.
- Signal processing: running integral of a sampled waveform.
- Physics simulation: cumulative displacement from velocity samples.
- Time-series accumulation: total rainfall, energy consumed, etc.
JAX does not ship a built-in cumtrapz. You build it from two primitives:
slice arithmetic (y[:-1], y[1:]) and jnp.cumsum.
Worked mini-example
import jax.numpy as jnp
y = jnp.array([0.0, 1.0, 2.0, 3.0])
dx = 1.0
trap_areas = (y[:-1] + y[1:]) / 2 * dx
# = [(0+1)/2, (1+2)/2, (2+3)/2] * 1.0
# = [0.5, 1.5, 2.5]
result = jnp.cumsum(trap_areas)
# = [0.5, 2.0, 4.5]
Output length is N - 1 (one fewer than input), because each element represents the area between a pair of adjacent samples.
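As a sanity check on the example above, the samples can be read as f(x) = x at x = 0, 1, 2, 3 (an interpretation assumed here, not stated in the example). The trapezoidal rule is exact for linear functions, so the running totals should match x**2 / 2 at the right edge of each step. The names x_right, exact, length_ok, and values_ok are illustrative only.

```python
import jax.numpy as jnp

# Same data as the worked example, read as f(x) = x sampled at x = 0, 1, 2, 3.
y = jnp.array([0.0, 1.0, 2.0, 3.0])
dx = 1.0

trap_areas = (y[:-1] + y[1:]) / 2 * dx   # one area per adjacent pair
result = jnp.cumsum(trap_areas)          # running total after each step

# Exact running integral of f(x) = x from 0 to x is x**2 / 2,
# evaluated at the right edge of each trapezoid (x = 1, 2, 3).
x_right = jnp.arange(1, y.shape[0]) * dx
exact = x_right ** 2 / 2

length_ok = result.shape == (y.shape[0] - 1,)   # N - 1 outputs
values_ok = bool(jnp.allclose(result, exact))   # matches the exact integral
```

For curved inputs the two would differ by the usual trapezoidal error, so this check only confirms the bookkeeping, not the accuracy of the method in general.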
Common pitfalls
- Output length N instead of N - 1: returning the cumsum of something length-N is wrong; each trapezoid needs two points.
- Forgetting the / 2: (y[i] + y[i+1]) * dx instead of (y[i] + y[i+1]) / 2 * dx doubles every area.
- Non-uniform spacing: this implementation assumes uniform dx. For non-uniform x coordinates, use jnp.trapezoid(y, x) (total only, no cumulative) or compute variable-width trapezoids manually.
- Off-by-one on slicing: y[:-1] gives all but the last; y[1:] gives all but the first. These two slices are the left and right edges of each trapezoid.
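For the non-uniform case mentioned in the pitfalls, one way to compute variable-width trapezoids manually is to replace the scalar dx with per-interval widths from jnp.diff(x). This is a minimal sketch, not a library function: the name cumulative_integral_nonuniform and the sample data are hypothetical, and x is assumed to be sorted in increasing order.

```python
import jax.numpy as jnp

def cumulative_integral_nonuniform(y, x):
    """Running trapezoidal integral for samples y at (sorted) positions x.

    Hypothetical helper: both inputs are 1-D arrays of length N,
    and the result has length N - 1.
    """
    widths = jnp.diff(x)               # x[i+1] - x[i], length N - 1
    heights = (y[:-1] + y[1:]) / 2     # mean of left and right edges
    return jnp.cumsum(heights * widths)

# f(x) = 2x on unevenly spaced points; the exact integral from 0 is x**2.
x = jnp.array([0.0, 1.0, 3.0, 6.0])
y = 2.0 * x
running = cumulative_integral_nonuniform(y, x)
# = [1.0, 9.0, 36.0]
```

With a constant spacing, jnp.diff(x) is just dx repeated, so this reduces to the uniform version.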
Problem
Implement cumulative_integral(y, dx), the cumulative trapezoidal integral.
- y: 1-D JAX array (N values).
- dx: scalar, uniform spacing between samples.
- Returns: 1-D array of shape (N - 1,), the running cumulative integral after each step.
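One possible shape for such a function, assembled from the slice arithmetic and jnp.cumsum steps of the worked mini-example. This is a sketch under the stated signature, not the site's reference solution; the velocity data and the jax.jit wrapping are illustrative additions.

```python
import jax
import jax.numpy as jnp

def cumulative_integral(y, dx):
    """Cumulative trapezoidal integral with uniform spacing dx.

    y: 1-D array of N samples; returns an array of shape (N - 1,).
    """
    trap_areas = (y[:-1] + y[1:]) / 2 * dx   # area between adjacent samples
    return jnp.cumsum(trap_areas)            # running total after each step

# Illustrative use: displacement from velocity samples taken every 0.5 s.
velocity = jnp.array([0.0, 2.0, 4.0, 6.0])
displacement = jax.jit(cumulative_integral)(velocity, 0.5)
# = [0.5, 2.0, 4.5]
```

Because the function uses only slicing and jnp.cumsum, it can be wrapped in jax.jit as shown; the output length depends only on the input shape, which is static under tracing.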
Hints
jax
trapz
integration