[Misc] Fix flashinfer related tests (#33462)
Signed-off-by: esmeetu <jasonailu87@gmail.com>
This commit is contained in:
@@ -74,7 +74,7 @@ def get_ref_results(
|
||||
@pytest.mark.parametrize("shape", SHAPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("backend", ["cutlass", "trtllm"])
|
||||
@pytest.mark.parametrize("backend", ["cutlass", "cudnn", "trtllm"])
|
||||
@pytest.mark.parametrize("autotune", [False, True])
|
||||
@torch.inference_mode()
|
||||
def test_flashinfer_nvfp4_gemm(
|
||||
|
||||
@@ -174,7 +174,7 @@ def test_static_fp8_quant_group_2d(
|
||||
f"group_shape ({group_shape[0]}, {group_shape[1]})"
|
||||
)
|
||||
|
||||
current_platform.seed_everything(seed)
|
||||
set_random_seed(seed)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
|
||||
ref_out, scale = scaled_quantize(
|
||||
@@ -202,7 +202,7 @@ def test_static_fp8_quant_1d_scale(
|
||||
group_shape: tuple[int, int],
|
||||
) -> None:
|
||||
"""Test static FP8 quantization with 1D scale (per-token or per-channel)."""
|
||||
current_platform.seed_everything(seed)
|
||||
set_random_seed(seed)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
|
||||
ref_out, scale_2d = scaled_quantize(
|
||||
|
||||
Reference in New Issue
Block a user