Difficulty: hard
Category: framework
Mixed Precision Forward Pass
Implement a simple linear forward pass using mixed precision: perform the matrix multiplication in float16 (half precision) for speed, then cast the result back to float32.
Compute: output = float32(x_fp16 @ W_fp16) + bias_fp32. In words, cast the float16 product back to float32, then add the float32 bias.
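Does it matter whether the bias is added before or after the cast back to float32? In JAX, adding a float16 array to a float32 array promotes the result to float32. So, as long as the bias stays in float32, "add then cast" and "cast then add" give identical values.

The sketch below checks this on small inputs. It is illustrative only: the input values are hypothetical, and the check relies on JAX's default type-promotion rules.

```python
import jax.numpy as jnp

# Hypothetical small float32 inputs, chosen only for illustration.
x = jnp.array([[0.1, 0.2], [0.3, 0.4]], dtype=jnp.float32)
W = jnp.array([[1.0, -1.0], [0.5, 2.0]], dtype=jnp.float32)
bias = jnp.array([0.01, -0.02], dtype=jnp.float32)

# Matrix multiplication on float16 operands; the product is float16.
prod16 = x.astype(jnp.float16) @ W.astype(jnp.float16)

# float16 + float32 already promotes to float32, so the final cast is a no-op.
add_then_cast = (prod16 + bias).astype(jnp.float32)
# Explicit cast of the product first, then a float32 addition.
cast_then_add = prod16.astype(jnp.float32) + bias

print((prod16 + bias).dtype)                                 # float32
print(bool(jnp.array_equal(add_then_cast, cast_then_add)))   # True
```

If the bias were instead cast down to float16 before the addition, the sum would be computed in half precision. Keeping the bias in float32 avoids that.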
Input:
- x: A 2D input tensor of shape (batch, in_features) in float32
- W: A 2D weight tensor of shape (in_features, out_features) in float32
- bias: A 1D bias tensor of shape (out_features,) in float32
Output: A 2D tensor of shape (batch, out_features) in float32.
Steps:
- Cast x and W to float16
- Compute the matrix multiplication in float16
- Cast the result back to float32
- Add the float32 bias
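The four steps above can be sketched in JAX as follows. This is a minimal, non-authoritative sketch, not the site's reference solution. Two assumptions apply:

- The function name `mixed_precision_linear` is hypothetical.
- The inputs are float32 `jax.numpy` arrays with the shapes listed under Input.

```python
import jax
import jax.numpy as jnp


def mixed_precision_linear(x, W, bias):
    """Hypothetical helper: float16 matmul, float32 cast and bias add."""
    # Step 1: cast the float32 inputs down to half precision.
    x16 = x.astype(jnp.float16)
    W16 = W.astype(jnp.float16)
    # Step 2: matrix multiplication on float16 operands.
    out16 = x16 @ W16  # shape (batch, out_features), dtype float16
    # Step 3: cast the product back to single precision.
    out32 = out16.astype(jnp.float32)
    # Step 4: add the float32 bias, broadcast across the batch dimension.
    return out32 + bias


# Usage: small random float32 inputs with batch=4, in_features=8, out_features=3.
key_x, key_w, key_b = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (4, 8), dtype=jnp.float32)
W = jax.random.normal(key_w, (8, 3), dtype=jnp.float32)
bias = jax.random.normal(key_b, (3,), dtype=jnp.float32)

out = mixed_precision_linear(x, W, bias)
ref = x @ W + bias  # full float32 reference for comparison

print(out.shape, out.dtype)                  # (4, 3) float32
print(float(jnp.max(jnp.abs(out - ref))))   # max deviation caused by float16 rounding
```

Casting back to float32 before the bias add keeps both the addition and the returned array in float32, as the specification requires.

Whether the float16 matmul is actually faster than a float32 one depends on the backend and hardware. The speed benefit is not guaranteed on every platform.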
API Reference:
- PyTorch: .half(), .float(), torch.float16
- JAX: .astype(jnp.float16), .astype(jnp.float32)
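To see what the `.astype(jnp.float16)` cast listed above does to values, round-trip a few float32 numbers. This is an illustrative JAX sketch. The sample values are hypothetical and were chosen to show two float16 limits:

- 2049.0 is not representable. float16 has an 11-bit significand, so integers above 2048 are spaced 2 apart, and 2049.0 rounds to 2048.0.
- 70000.0 exceeds the largest finite float16 value, 65504, so it becomes infinity.

```python
import jax.numpy as jnp

# Hypothetical sample values: exactly representable, rounded, and overflowing.
vals32 = jnp.array([1.5, 2049.0, 70000.0], dtype=jnp.float32)

# Narrowing cast: round each value to the nearest float16 (overflow becomes inf).
vals16 = vals32.astype(jnp.float16)
# Widening cast back to float32 is exact, so it exposes the float16 values.
back32 = vals16.astype(jnp.float32)

# Expected: 1.5 survives, 2049.0 becomes 2048.0, 70000.0 becomes inf.
print(back32)
print(jnp.finfo(jnp.float16).max)  # largest finite float16 value
```

Because of these limits, large activations or weights can overflow when cast to float16, while the same values are safe in float32.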
Hints: mixed-precision, float16, torch.half, jnp.float16, performance