Use JIT compilation to optimize a function that computes f(x) = sin(x)^2 + cos(x)^2.
By the Pythagorean identity, the function should return 1.0 for any input, but the goal is to demonstrate JIT compilation, not to simplify the math.
Input: A tensor x of any shape.
Output: A tensor of the same shape where every element is 1.0 (within floating-point tolerance).
API Reference:
jax.jit(fn)
torch.compile(fn) (PyTorch 2.0+)
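One possible solution can be sketched in JAX with jax.jit. This is a minimal illustration, not a reference answer: the names pythagorean and pythagorean_jit and the sample input are assumptions made for the example, and the tolerance of 1e-6 is chosen for JAX's default float32 precision.

```python
import jax
import jax.numpy as jnp


def pythagorean(x):
    # Compute both terms explicitly; the math is intentionally not simplified to 1.
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


# jax.jit returns a wrapped function that is traced and compiled on its first call
# for a given input shape and dtype, then reuses the compiled version afterwards.
pythagorean_jit = jax.jit(pythagorean)

# Hypothetical sample input; any shape should work.
x = jnp.linspace(-3.0, 3.0, 12).reshape(3, 4)
y = pythagorean_jit(x)

print(y.shape)
print(bool(jnp.allclose(y, 1.0, atol=1e-6)))
```

With PyTorch 2.0+, the same pattern would presumably apply by wrapping an equivalent function in torch.compile(fn) and comparing the result against a tensor of ones within tolerance.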