[Misc] Fix flashinfer related tests (#33462)

Signed-off-by: esmeetu <jasonailu87@gmail.com>
This commit is contained in:
Roy Wang
2026-02-01 05:10:24 +08:00
committed by GitHub
parent 1e86c802d4
commit 63c0889416
5 changed files with 9 additions and 8 deletions

View File

@@ -174,7 +174,7 @@ def test_static_fp8_quant_group_2d(
f"group_shape ({group_shape[0]}, {group_shape[1]})"
)
current_platform.seed_everything(seed)
set_random_seed(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
ref_out, scale = scaled_quantize(
@@ -202,7 +202,7 @@ def test_static_fp8_quant_1d_scale(
group_shape: tuple[int, int],
) -> None:
"""Test static FP8 quantization with 1D scale (per-token or per-channel)."""
current_platform.seed_everything(seed)
set_random_seed(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
ref_out, scale_2d = scaled_quantize(