Replace elements in a tensor where a boolean mask is True with a given fill value.
Input:
x: A tensor of any shape.
mask: A mask with the same shape as x; True (or 1.0) marks a position to replace, and False (or 0.0) leaves it unchanged.
fill_value: The scalar value written at every position where mask is True.
Output: A tensor of the same shape as x, with masked positions set to fill_value and all other positions copied from x.
API Reference:
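PyTorch: x.masked_fill(mask, fill_value) or torch.where(mask, fill_value, x)
JAX: jnp.where(mask, fill_value, x)
The PyTorch calls expect a boolean mask, so convert a 1.0/0.0 float mask first (e.g. mask.bool()).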
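A minimal PyTorch sketch of the spec above, assuming the mask arrives as a float tensor of 1.0/0.0 values. The helper name `masked_replace` is hypothetical and not part of any library. It converts the mask with `.bool()` before calling `masked_fill`, then checks that the `torch.where` form gives the same result.

```python
import torch


def masked_replace(x: torch.Tensor, mask: torch.Tensor, fill_value: float) -> torch.Tensor:
    """Hypothetical helper: copy of x with positions where mask is True set to fill_value."""
    # masked_fill expects a bool mask, so a 1.0/0.0 float mask is converted
    # first (any nonzero entry becomes True).
    bool_mask = mask.bool()
    # Out-of-place masked_fill returns a new tensor; x itself is not modified.
    return x.masked_fill(bool_mask, fill_value)


x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])
mask = torch.tensor([[1.0, 0.0, 0.0],
                     [0.0, 1.0, 1.0]])

out = masked_replace(x, mask, -1.0)
print(out.tolist())  # [[-1.0, 2.0, 3.0], [4.0, -1.0, -1.0]]

# Equivalent torch.where form: the fill value is wrapped in a 0-d tensor of
# x's dtype, which broadcasts against x.
alt = torch.where(mask.bool(), torch.tensor(-1.0, dtype=x.dtype), x)
assert torch.equal(out, alt)
```

In JAX the same operation is `jnp.where(mask.astype(bool), fill_value, x)`; since JAX arrays are immutable, it likewise returns a new array rather than modifying x.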