medium framework

JIT Compile a Function

Use JIT compilation to optimize a function that computes f(x) = sin(x)^2 + cos(x)^2.

By the Pythagorean identity, this always returns 1.0 for any input, but the goal is to demonstrate JIT compilation, not to simplify the math.

Input: A tensor x of any shape.

Output: A tensor of the same shape where every element is 1.0 (within floating-point tolerance).

API Reference:

  • JAX: jax.jit(fn)
  • PyTorch: torch.compile(fn) (PyTorch 2.0+)
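
A minimal JAX sketch of the task, assuming a float32 input. The function name `pythagorean`, the example tensor, and the tolerance `atol=1e-6` are illustrative choices, not part of the problem statement.

```python
import jax
import jax.numpy as jnp


def pythagorean(x):
    # Computed literally as sin(x)^2 + cos(x)^2; no simplification by hand.
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


# jax.jit returns a wrapped function that is traced and compiled
# on its first call for a given input shape and dtype.
pythagorean_jit = jax.jit(pythagorean)

# Arbitrary example input (hypothetical values); any shape works.
x = jnp.linspace(-3.0, 3.0, 12).reshape(3, 4)
y = pythagorean_jit(x)

# The output has the same shape as the input, and every element
# should be 1.0 within floating-point tolerance.
print(y.shape)
print(bool(jnp.allclose(y, 1.0, atol=1e-6)))
```

Later calls with the same shape and dtype reuse the compiled version. A PyTorch solution would follow the same pattern, wrapping the function with torch.compile(fn) instead.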

Hints

JIT compilation, jax.jit, torch.compile