[Kernel] Support Flashinfer trtllm fused MoE non gated FP8 & NVFP4 (#33506)

Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
This commit is contained in:
amitz-nv
2026-02-12 23:06:58 +02:00
committed by GitHub
parent fac4e96940
commit f120bd42d3
5 changed files with 197 additions and 45 deletions

View File

@@ -71,7 +71,8 @@ def quant_fp8_per_tensor_batches(a):
for i in range(num_batches):
a_fp8, a_global_sf = input_to_float8(a[i])
a_global_sf = 1.0 / a_global_sf
if a_global_sf.numel() == 1:
a_global_sf = a_global_sf.view(1, 1)
a_quant.append(a_fp8)
a_scales.append(a_global_sf)
@@ -81,6 +82,20 @@ def quant_fp8_per_tensor_batches(a):
return result_a_quant, result_a_scales
def check_accuracy(ref_output, actual_output, atol=0.1, rtol=0.85, percent=0.925):
    """Assert that enough elements of ``actual_output`` match ``ref_output``.

    Unlike ``torch.testing.assert_close``, this tolerates a bounded fraction
    of mismatching elements, which suits fused-MoE kernels whose per-element
    results can differ slightly from the reference path.

    Args:
        ref_output: Reference tensor.
        actual_output: Tensor under test; same shape as ``ref_output``.
        atol: Absolute tolerance passed to ``torch.isclose``.
        rtol: Relative tolerance passed to ``torch.isclose``.
        percent: Minimum fraction of elements that must be close.

    Raises:
        AssertionError: If fewer than ``percent`` of the elements match.
    """
    close = torch.isclose(ref_output, actual_output, atol=atol, rtol=rtol)
    match_ratio = close.float().mean()
    # NOTE(review): the original carried a second, mathematically equivalent
    # assertion on (1 - match_ratio); one check suffices.
    assert match_ratio >= percent, (
        f"Match ratio {match_ratio:.4f} is below the threshold {percent:.4f}"
    )
@dataclass
class TestData:
hidden_states: torch.Tensor
@@ -104,14 +119,16 @@ class TestData:
is_gated = activation.is_gated
hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
w13 = torch.randn(
(e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16
w13 = (
torch.randn(
(e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16
)
/ 10
)
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10
# Scale to fp8
_, a1_scale = input_to_float8(hidden_states)
a1_scale = 1.0 / a1_scale
a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(dtype=torch.float32)
w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)
@@ -124,14 +141,16 @@ class TestData:
layer.w2_input_scale = a2_scale
layer.w13_weight_scale = w13_weight_scale
layer.w2_weight_scale = w2_weight_scale
layer.activation = activation
# Setup dummy config.
layer.moe_parallel_config = mk.FusedMoEParallelConfig.make_no_parallel()
# flashinfer expects swapped rows for w13
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
if is_gated:
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
if is_trtllm:
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
layer.w13_weight, layer.w2_weight
layer.w13_weight, layer.w2_weight, is_gated
)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer,
@@ -162,12 +181,14 @@ class TestData:
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
def test_flashinfer_per_tensor_moe_fp8_no_graph(
m: int,
n: int,
k: int,
e: int,
topk: int,
activation: MoEActivation,
monkeypatch,
):
if not current_platform.has_device_capability(100):
@@ -175,7 +196,9 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit(m, k, n, e, is_trtllm=True)
td = TestData.make_moe_tensors_8bit(
m, k, n, e, is_trtllm=True, activation=activation
)
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids = Llama4MoE.custom_routing_function(
@@ -200,7 +223,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
activation=MoEActivation.SILU,
activation=activation,
global_num_experts=e,
expert_map=None,
apply_router_weight_on_input=True,
@@ -219,7 +242,13 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
apply_router_weight_on_input=True,
)
torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2)
check_accuracy(
ref_output=output,
actual_output=flashinfer_output,
atol=0.1,
rtol=0.85,
percent=0.925,
)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@@ -320,8 +349,13 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
expert_map=None,
apply_router_weight_on_input=True,
)
torch.testing.assert_close(
output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2
check_accuracy(
ref_output=output,
actual_output=flashinfer_cutlass_output,
atol=0.1,
rtol=0.85,
percent=0.925,
)