easy
primitives
Clip and Extrema
Why this matters
Clipping and elementwise extrema are ubiquitous in ML:
- PPO (policy gradients): clip the probability ratio to [1-ε, 1+ε].
- Gradient clipping: compute the clipped norm jnp.clip(grad_norm, 0, max_norm), then rescale the gradients by clipped_norm / grad_norm.
- Reward shaping: clip rewards to [-R, R], subtract a baseline, then floor at 0.
- ReLU and friends: jnp.maximum(x, 0) is a ReLU; jnp.minimum(x, 0) gives the non-positive part, so x == jnp.maximum(x, 0) + jnp.minimum(x, 0).
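A minimal sketch of the first two uses, assuming gradients are stored as a JAX pytree (for example a dict of arrays). The function names ppo_clipped_objective and clip_by_global_norm are hypothetical. Taking the elementwise minimum of the unclipped and clipped surrogates follows the standard PPO-Clip objective; production code would more likely rely on a tested utility such as optax.clip_by_global_norm.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: PPO-Clip surrogate, computed elementwise.
def ppo_clipped_objective(ratio, advantage, eps=0.2):
    # Clip the probability ratio to [1 - eps, 1 + eps].
    clipped_ratio = jnp.clip(ratio, 1.0 - eps, 1.0 + eps)
    # Pessimistic bound: elementwise minimum of unclipped and clipped terms.
    return jnp.minimum(ratio * advantage, clipped_ratio * advantage)

# Hypothetical helper: clip a gradient pytree by its global L2 norm.
def clip_by_global_norm(grads, max_norm):
    leaves = jax.tree_util.tree_leaves(grads)
    grad_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    # Clip the norm itself, then rescale every leaf by clipped / original.
    clipped_norm = jnp.clip(grad_norm, 0.0, max_norm)
    scale = clipped_norm / jnp.maximum(grad_norm, 1e-12)  # avoid 0 / 0
    return jax.tree_util.tree_map(lambda g: g * scale, grads)

# Usage: ratios above 1 + eps stop contributing extra objective.
obj = ppo_clipped_objective(jnp.array([0.5, 1.0, 1.5]), jnp.ones(3), eps=0.2)
# Usage: a gradient of norm 5 is rescaled to norm 1.
grads = clip_by_global_norm({"w": jnp.array([3.0, 4.0])}, max_norm=1.0)
```

Gradients whose norm is already below max_norm get scale 1 and pass through unchanged.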
JAX provides:
- jnp.clip(x, lo, hi): two-sided clamp.
- jnp.maximum(a, b): elementwise maximum (broadcast-compatible).
- jnp.minimum(a, b): elementwise minimum (broadcast-compatible).
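A short illustration of the elementwise and broadcasting behavior of these three functions; the input values and threshold arrays are arbitrary choices for the example.

```python
import jax.numpy as jnp

x = jnp.array([-2.0, -0.5, 0.0, 1.5])

# Positive part (ReLU) and non-positive part of x.
relu = jnp.maximum(x, 0.0)
neg_part = jnp.minimum(x, 0.0)
# The two parts always add back up to x.
assert jnp.allclose(relu + neg_part, x)

# Broadcasting: a (3, 1) column of thresholds against the (4,) vector x
# yields a (3, 4) array, one row per threshold.
thresholds = jnp.array([[-1.0], [0.0], [1.0]])
grid = jnp.maximum(x, thresholds)
assert grid.shape == (3, 4)

# jnp.clip broadcasts its bounds as well, so per-element limits work.
lo = jnp.array([-1.0, -1.0, 0.5, 0.5])
hi = jnp.array([1.0, 1.0, 1.0, 1.0])
per_elem = jnp.clip(x, lo, hi)
```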
Worked mini-example
```python
import jax.numpy as jnp

x = jnp.array([-2.0, 0.5, 3.0])
clipped = jnp.clip(x, 0.0, 1.0)       # [0.0, 0.5, 1.0]
shifted = clipped - 0.3               # [-0.3, 0.2, 0.7]
floored = jnp.maximum(shifted, 0.0)   # [0.0, 0.2, 0.7]
```
Common pitfalls
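- jnp.max vs jnp.maximum: jnp.max(x) is a reduction (by default over all axes, giving a scalar); jnp.maximum(x, 0) is elementwise. They are very different.
- jnp.clip vs jnp.minimum/jnp.maximum: jnp.clip(x, lo, hi) is equivalent to jnp.minimum(jnp.maximum(x, lo), hi), so either form works.
- Scalar vs array: jnp.maximum(x, 0) broadcasts the scalar 0 to the shape of x, which works fine. Don't confuse it with jnp.maximum(x), which fails because jnp.maximum requires two arguments; use jnp.max(x) for a reduction.
- NaN propagation: jnp.maximum(NaN, 0) returns NaN, and jnp.clip propagates NaN the same way. If inputs may contain NaN, clean them first (for example with jnp.nan_to_num), or use jnp.fmax, which returns the non-NaN operand.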
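The pitfalls above can be checked directly. This sketch uses arbitrary inputs and only calls that exist in jax.numpy (jnp.max, jnp.maximum, jnp.minimum, jnp.clip, jnp.nan_to_num, jnp.fmax).

```python
import jax.numpy as jnp

x = jnp.array([-2.0, 0.5, 3.0])

# Reduction vs elementwise: jnp.max collapses the array, jnp.maximum does not.
assert jnp.max(x).shape == ()
assert jnp.maximum(x, 0.0).shape == x.shape

# jnp.clip matches maximum-then-minimum.
lo, hi = 0.0, 1.0
assert jnp.array_equal(jnp.clip(x, lo, hi), jnp.minimum(jnp.maximum(x, lo), hi))

# NaN propagates through jnp.maximum ...
y = jnp.array([jnp.nan, -1.0, 2.0])
assert jnp.isnan(jnp.maximum(y, 0.0)[0])

# ... unless it is replaced first, or jnp.fmax is used (it prefers the non-NaN operand).
cleaned = jnp.maximum(jnp.nan_to_num(y, nan=0.0), 0.0)
via_fmax = jnp.fmax(y, 0.0)
assert jnp.array_equal(cleaned, jnp.array([0.0, 0.0, 2.0]))
assert jnp.array_equal(via_fmax, jnp.array([0.0, 0.0, 2.0]))
```

Replacing NaN with 0 before flooring and using fmax give the same result here, but they differ in intent: nan_to_num also lets you choose a different fill value.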
Problem
Implement clip_and_subtract_baseline(x, lo, hi, baseline) that:
- Clips x to [lo, hi].
- Subtracts baseline from the clipped values.
- Floors the result at 0 (elementwise max with 0).
It returns a 1-D array of the same shape as x.
Two illustrative examples (not from the test set):
- x = [0.0, 0.5, 2.0], lo = 0.0, hi = 1.0, baseline = 0.0: clip → [0, 0.5, 1]; subtract 0 → [0, 0.5, 1]; max(., 0) → [0, 0.5, 1].
- x = [-1.0, 3.0], lo = 0.0, hi = 2.0, baseline = 1.5: clip → [0, 2]; subtract 1.5 → [-1.5, 0.5]; max(., 0) → [0, 0.5].
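One possible sketch of clip_and_subtract_baseline, composing the three steps in the order the problem specifies. It is not the site's reference solution; wrapping it in jax.jit is an optional choice, and the same code works unchanged for inputs of any shape, not only 1-D.

```python
import jax
import jax.numpy as jnp

@jax.jit  # optional: compiles the three elementwise ops into one kernel
def clip_and_subtract_baseline(x, lo, hi, baseline):
    x = jnp.asarray(x)
    clipped = jnp.clip(x, lo, hi)       # step 1: clamp to [lo, hi]
    shifted = clipped - baseline        # step 2: subtract the baseline
    return jnp.maximum(shifted, 0.0)    # step 3: floor at 0

# The two illustrative examples from the problem statement.
out1 = clip_and_subtract_baseline(jnp.array([0.0, 0.5, 2.0]), 0.0, 1.0, 0.0)
out2 = clip_and_subtract_baseline(jnp.array([-1.0, 3.0]), 0.0, 2.0, 1.5)
assert jnp.allclose(out1, jnp.array([0.0, 0.5, 1.0]))
assert jnp.allclose(out2, jnp.array([0.0, 0.5]))
```

Because lo, hi, and baseline are traced rather than static under jit, changing their values does not trigger recompilation; only a change in input shape or dtype does.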
Hints
- jax
- clip
- minimum
- maximum