
Clip and Extrema

Why this matters

Clipping and elementwise extrema are ubiquitous in ML:

  • PPO (policy gradients): clip the probability ratio to [1-ε, 1+ε].
  • Gradient clipping by norm: compute the global gradient norm, cap it with jnp.clip(grad_norm, 0, max_norm), then rescale every gradient by clipped_norm / grad_norm.
  • Reward shaping: clip rewards to [-R, R], then subtract a baseline, then floor at 0.
  • ReLU and friends: jnp.maximum(x, 0) is a ReLU; jnp.minimum(x, 0) keeps the non-positive part (x where x < 0, else 0).
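
A minimal sketch of the first two patterns plus the ReLU identity, assuming float inputs. The names ppo_clipped_objective, clip_by_global_norm and non_positive_part are hypothetical helpers written for illustration, eps=0.2 is just a common choice, and the 1e-12 guard against dividing by zero is an assumption. In practice a library such as optax ships its own clip_by_global_norm transformation.

```python
import jax
import jax.numpy as jnp

def ppo_clipped_objective(ratio, advantage, eps=0.2):
    # Hypothetical helper: PPO clipped surrogate objective (to be maximized).
    clipped_ratio = jnp.clip(ratio, 1.0 - eps, 1.0 + eps)
    # Elementwise minimum of the unclipped and clipped terms is the pessimistic bound.
    return jnp.mean(jnp.minimum(ratio * advantage, clipped_ratio * advantage))

def clip_by_global_norm(grads, max_norm):
    # Hypothetical helper: rescale a gradient pytree so its global L2 norm is <= max_norm.
    leaves = jax.tree_util.tree_leaves(grads)
    grad_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    clipped_norm = jnp.clip(grad_norm, 0.0, max_norm)
    # Assumed guard: if every gradient is zero, scale becomes 0 instead of 0/0 = NaN.
    scale = clipped_norm / jnp.maximum(grad_norm, 1e-12)
    return jax.tree_util.tree_map(lambda g: g * scale, grads)

def relu(x):
    # ReLU as an elementwise maximum with 0.
    return jnp.maximum(x, 0.0)

def non_positive_part(x):
    # Keeps x where x < 0 and returns 0 elsewhere.
    return jnp.minimum(x, 0.0)
```

For example, ratios [0.5, 1.0, 1.5] with unit advantages and eps=0.2 give clipped terms [0.8, 1.0, 1.2], an elementwise minimum of [0.5, 1.0, 1.2], and an objective of 0.9. A gradient {"w": [3.0, 4.0]} has norm 5, so with max_norm=1.0 it is scaled by 0.2 to [0.6, 0.8].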

JAX provides:

  • jnp.clip(x, lo, hi) — two-sided clamp.
  • jnp.maximum(a, b) — elementwise supremum (broadcast-compatible).
  • jnp.minimum(a, b) — elementwise infimum.
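
To make "broadcast-compatible" concrete, the sketch below uses made-up values: a (3, 1) column of lower bounds against a (4,) row of inputs produces a (3, 4) result, and jnp.clip broadcasts its bounds in the same way.

```python
import jax.numpy as jnp

# Made-up inputs: one lower bound per row (shape (3, 1)) and a row of values (shape (4,)).
lo = jnp.array([[0.0], [1.0], [2.0]])
x = jnp.array([-1.0, 0.5, 1.5, 3.0])

# (3, 1) broadcast against (4,) gives a (3, 4) result: row i is max(x, lo[i]).
floored = jnp.maximum(x, lo)

# jnp.clip broadcasts its bounds too: per-row lower bound, shared upper bound 2.5.
clipped = jnp.clip(x, lo, 2.5)

# jnp.minimum with a scalar caps every element at 1.0 and keeps shape (4,).
capped = jnp.minimum(x, 1.0)
```

Row 0 of floored is [0.0, 0.5, 1.5, 3.0], row 2 is [2.0, 2.0, 2.0, 3.0], and capped is [-1.0, 0.5, 1.0, 1.0].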

Worked mini-example

import jax.numpy as jnp

x        = jnp.array([-2.0, 0.5, 3.0])
clipped  = jnp.clip(x, 0.0, 1.0)      # [0.0, 0.5, 1.0]
shifted  = clipped - 0.3              # [-0.3, 0.2, 0.7]
floored  = jnp.maximum(shifted, 0.0)  # [0.0, 0.2, 0.7]
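
Since PPO depends on clipped entries contributing no gradient, it can help to check what jax.grad sees through this same pipeline. The function name pipeline and the final sum are illustrative choices; summing only turns the output into a scalar so jax.grad applies.

```python
import jax
import jax.numpy as jnp

def pipeline(x):
    # Same three steps as the worked example, reduced to a scalar for jax.grad.
    clipped = jnp.clip(x, 0.0, 1.0)
    shifted = clipped - 0.3
    return jnp.sum(jnp.maximum(shifted, 0.0))

x = jnp.array([-2.0, 0.5, 3.0])
g = jax.grad(pipeline)(x)  # [0., 1., 0.]
```

Only the middle element, which sits strictly inside the clip range and stays positive after the shift, receives gradient; the elements clipped to 0 and to 1 get zero.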

Common pitfalls

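  • jnp.max vs jnp.maximum: jnp.max(x) is a reduction (over all axes by default, or over the given axis), while jnp.maximum(x, 0) is elementwise and keeps the shape of x.
  • jnp.clip vs jnp.minimum/jnp.maximum: jnp.clip(x, lo, hi) behaves like jnp.minimum(jnp.maximum(x, lo), hi). Because the upper bound is applied last, passing lo > hi raises no error and silently returns hi everywhere.
  • Scalar vs array: jnp.maximum(x, 0) broadcasts the scalar 0 against x, so there is no need to build an array of zeros. jnp.maximum always takes two arguments; calling jnp.maximum(x) raises an error. For a single-array reduction, use jnp.max(x).
  • NaN propagation: jnp.maximum(NaN, 0) returns NaN, and jnp.clip propagates NaN the same way. If inputs may contain NaN, clean them first (for example with jnp.nan_to_num) or use jnp.fmax/jnp.fmin, which return the non-NaN operand.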
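
A short check of these pitfalls on made-up inputs. jnp.fmax and jnp.nan_to_num are the NaN-tolerant alternatives used here; whether replacing NaN with 0 is appropriate depends on the application.

```python
import jax.numpy as jnp

x = jnp.array([-1.0, 2.0, 5.0])

# Reduction vs elementwise: jnp.max collapses to a 0-d array, jnp.maximum keeps x's shape.
reduced = jnp.max(x)                 # 5.0, shape ()
elementwise = jnp.maximum(x, 0.0)    # [0.0, 2.0, 5.0], shape (3,)

# jnp.clip matches the nested minimum/maximum form.
lo, hi = 0.0, 3.0
clipped = jnp.clip(x, lo, hi)                  # [0.0, 2.0, 3.0]
nested = jnp.minimum(jnp.maximum(x, lo), hi)   # same values

# NaN propagates through jnp.maximum (and jnp.clip) ...
y = jnp.array([jnp.nan, 2.0])
propagated = jnp.maximum(y, 0.0)     # [nan, 2.0]
# ... whereas jnp.fmax returns the non-NaN operand ...
ignored = jnp.fmax(y, 0.0)           # [0.0, 2.0]
# ... and jnp.nan_to_num replaces NaN before the max is taken.
cleaned = jnp.maximum(jnp.nan_to_num(y, nan=0.0), 0.0)  # [0.0, 2.0]
```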

Problem

Implement clip_and_subtract_baseline(x, lo, hi, baseline) that:

  1. Clips x to [lo, hi].
  2. Subtracts baseline from the clipped values.
  3. Floors at 0 (elementwise max with 0).

Returns a 1-D array of the same shape as x.

Two illustrative examples (not from the test set):

  • x = [0.0, 0.5, 2.0], lo = 0.0, hi = 1.0, baseline = 0.0: clip → [0, 0.5, 1]; subtract 0 → [0, 0.5, 1]; max(., 0) → [0, 0.5, 1].

  • x = [-1.0, 3.0], lo = 0.0, hi = 2.0, baseline = 1.5: clip → [0, 2]; subtract 1.5 → [-1.5, 0.5]; max(., 0) → [0, 0.5].

Hints

jnp.clip, jnp.minimum, jnp.maximum
