Add support for ModelOpt MXFP8 MoE models (#35986)

Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
This commit is contained in:
danisereb
2026-03-08 22:00:05 +02:00
committed by GitHub
parent 4497431df6
commit 0a6a3a1290
4 changed files with 597 additions and 18 deletions

View File

@@ -20,6 +20,8 @@ TRTLLM_GEN_MXFP4_AVAILABLE = (
current_platform.is_cuda() and current_platform.is_device_capability_family(100)
)
TRTLLM_GEN_MXFP8_AVAILABLE = TRTLLM_GEN_MXFP4_AVAILABLE
HOPPER_MXFP4_BF16_AVAILABLE = (
current_platform.is_cuda()
and current_platform.is_device_capability(90)
@@ -34,9 +36,15 @@ if TRTLLM_GEN_MXFP4_AVAILABLE:
shuffle_matrix_a,
shuffle_matrix_sf_a,
trtllm_fp4_block_scale_moe,
trtllm_fp8_block_scale_moe,
)
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
if TRTLLM_GEN_MXFP8_AVAILABLE:
from flashinfer.fused_moe.core import (
Fp8QuantizationType,
get_w2_permute_indices_with_cache,
)
@dataclass
@@ -160,6 +168,7 @@ def reference_moe(
beta,
limit,
act_type,
is_gated,
):
# renormalize routing
experts = torch.topk(roouting_logits, k=topk, dim=-1, sorted=True)
@@ -170,7 +179,12 @@ def reference_moe(
mlp1_weight = w13[expert_indices, ...]
mlp1_bias = bias13[expert_indices, ...]
t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
t = swiglu(t, alpha=alpha, beta=beta, limit=limit)
if is_gated:
t = swiglu(t, alpha=alpha, beta=beta, limit=limit)
else:
# RELU2_NO_MUL: relu(x)^2
t = torch.relu(t)
t = t * t
if act_type == "mxfp8":
t_quantized, t_scale = mxfp8_quantize(
@@ -569,6 +583,7 @@ def test_trtllm_gen_mxfp4_fused_moe(
beta,
limit,
act_type,
is_gated=True,
)
ref_result[start_idx:end_idx].copy_(chunk_result)
@@ -705,6 +720,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
beta,
limit,
"bf16",
is_gated=True,
)
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
@@ -890,6 +906,7 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
beta,
limit,
"mxfp8",
is_gated=True,
)
# Prepare inputs for FlashInfer CUTLASS fused MoE
@@ -965,3 +982,169 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
# Allow some mismatch due to MXFP4 quantization
check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("is_gated", [True], ids=["gated"])
@pytest.mark.skipif(
not TRTLLM_GEN_MXFP8_AVAILABLE,
reason="nvidia gpu and compute capability sm100 is required for this test",
)
def test_trtllm_gen_mxfp8_block_scale_moe(
topk: int,
num_experts: int,
num_tokens: int,
intermediate_size: int,
hidden_size: int,
is_gated: bool,
):
torch.manual_seed(42)
device = "cuda:0"
inter_size = intermediate_size * (2 if is_gated else 1)
hidden_states = (
torch.randn(num_tokens, hidden_size, device=device, dtype=torch.bfloat16) / 20
)
w13 = (
torch.randn(
num_experts,
inter_size,
hidden_size,
device=device,
dtype=torch.bfloat16,
)
/ 20
)
w2 = (
torch.randn(
num_experts,
hidden_size,
intermediate_size,
device=device,
dtype=torch.bfloat16,
)
/ 20
)
router_logits = torch.rand(
num_tokens, num_experts, dtype=torch.float32, device=device
)
router_logits_kernel = router_logits.to(torch.bfloat16)
# Quantize weights to MXFP8 and normalize scales to [E, M, K//32].
w13_q, w13_scale = mxfp8_quantize(w13, is_sf_swizzled_layout=False)
w2_q, w2_scale = mxfp8_quantize(w2, is_sf_swizzled_layout=False)
if w13_scale.ndim == 1:
w13_scale = w13_scale.view(
num_experts,
inter_size,
hidden_size // 32,
)
if w2_scale.ndim == 1:
w2_scale = w2_scale.view(num_experts, hidden_size, intermediate_size // 32)
# Quantize activations to MXFP8.
hidden_states_q, hidden_states_scale = mxfp8_quantize(
hidden_states, is_sf_swizzled_layout=False
)
if hidden_states_scale.ndim == 1:
hidden_states_scale = hidden_states_scale.view(num_tokens, hidden_size // 32)
# Reference output using dequantized tensors + MXFP8 intermediate quantization.
w13_ref = mxfp8_dequantize(w13_q, w13_scale).to(torch.float32)
w2_ref = mxfp8_dequantize(w2_q, w2_scale).to(torch.float32)
hidden_states_ref = mxfp8_dequantize(hidden_states_q, hidden_states_scale).to(
torch.float32
)
bias13 = torch.zeros(
num_experts,
intermediate_size * (2 if is_gated else 1),
device=device,
)
bias2 = torch.zeros(num_experts, hidden_size, device=device)
ref = reference_moe(
router_logits_kernel.to(torch.float32),
topk,
num_experts,
hidden_states_ref,
w13_ref,
bias13,
w2_ref,
bias2,
alpha=1.0,
beta=0.0,
limit=None,
act_type="mxfp8",
is_gated=is_gated,
)
# Shuffle weights/scales with the same indexed layout used by TRTLLM kernels.
epilogue_tile_m = 128
gemm1_weights_shuffled = []
gemm1_scales_shuffled = []
gemm2_weights_shuffled = []
gemm2_scales_shuffled = []
for i in range(num_experts):
w13_rows = intermediate_size * (2 if is_gated else 1)
w13_interleaved = w13_q[i].clone().reshape(w13_rows, -1)
w13_scale_interleaved = w13_scale[i].clone().reshape(w13_rows, -1)
if is_gated:
w13_interleaved = reorder_rows_for_gated_act_gemm(w13_interleaved)
w13_scale_interleaved = reorder_rows_for_gated_act_gemm(
w13_scale_interleaved
)
gemm1_weights_shuffled.append(
shuffle_matrix_a(w13_interleaved.view(torch.uint8), epilogue_tile_m)
.contiguous()
.view(w13_q.dtype)
)
gemm2_weights_shuffled.append(
shuffle_matrix_a(w2_q[i].view(torch.uint8), epilogue_tile_m)
.contiguous()
.view(w2_q.dtype)
)
gemm1_scales_shuffled.append(
shuffle_matrix_sf_a(
w13_scale_interleaved.view(torch.uint8).reshape(w13_rows, -1),
epilogue_tile_m,
)
.contiguous()
.view(w13_scale.dtype)
)
gemm2_scales_shuffled.append(
shuffle_matrix_sf_a(
w2_scale[i].view(torch.uint8).reshape(hidden_size, -1), epilogue_tile_m
)
.contiguous()
.view(w2_scale.dtype)
)
out = trtllm_fp8_block_scale_moe(
routing_logits=router_logits_kernel,
routing_bias=None,
hidden_states=hidden_states_q,
hidden_states_scale=hidden_states_scale,
gemm1_weights=torch.stack(gemm1_weights_shuffled),
gemm1_weights_scale=torch.stack(gemm1_scales_shuffled),
gemm2_weights=torch.stack(gemm2_weights_shuffled),
gemm2_weights_scale=torch.stack(gemm2_scales_shuffled),
num_experts=num_experts,
top_k=topk,
n_group=None,
topk_group=None,
intermediate_size=intermediate_size,
local_expert_offset=0,
local_num_experts=num_experts,
routed_scaling_factor=None,
routing_method_type=1, # renormalize routing
use_shuffled_weight=True,
weight_layout=0, # MajorK
fp8_quantization_type=Fp8QuantizationType.MxFp8,
)
# Block-scale MXFP8 kernels are approximate; require majority close.
check_accuracy(ref, out, atol=0.1, rtol=0.85, percent=0.8)