diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py index a130f9acb..4ff1e590a 100644 --- a/tests/kernels/attention/test_cache.py +++ b/tests/kernels/attention/test_cache.py @@ -242,7 +242,7 @@ def test_reshape_and_cache_flash( value_cache_compact = permute_and_compact(value_cache) def convert_fp8_local(output, input, scale, kv_dtype): - fp8_input = input.view(torch.float8_e4m3fn) + fp8_input = input.view(current_platform.fp8_dtype()) if scale.numel() == 1: # per-tensor result = scaled_dequantize( fp8_input.flatten(0, 2), scale, group_shape=None, out_dtype=output.dtype