[Quant][Feature] Support online MXFP8 quantization for MoE and dense models (#35448)

Signed-off-by: EdalatiAli <aliedalati@cohere.com>
This commit is contained in:
EdalatiAli
2026-03-16 18:07:39 -04:00
committed by GitHub
parent fd4d96302a
commit e5b807607c
10 changed files with 747 additions and 56 deletions

View File

@@ -305,6 +305,81 @@ def align_fp8_moe_weights_for_fi(
return padded_w13, padded_w2, padded_intermediate
def _shuffle_mxfp8_moe_weights(
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    is_gated: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Preprocess MXFP8 weights and scales for the FlashInfer TRT-LLM kernel.

    Mirrors flashinfer/tests/moe/test_trtllm_gen_fused_moe.py:
      1. reorder_rows_for_gated_act_gemm — interleave gate/up rows (gated only)
      2. shuffle_matrix_a — shuffle the weight data layout
      3. shuffle_matrix_sf_a — shuffle the scale-factor layout

    Args:
        w13: Gate/up projection weights, shape [E, 2 * intermediate, hidden]
            (assumed from the `// 2` split below — confirm against caller).
        w2: Down projection weights.
        w13_scale: Per-block scale factors for w13 (uint8-viewable).
        w2_scale: Per-block scale factors for w2 (uint8-viewable).
        is_gated: Whether the activation is gated; controls row interleaving.

    Returns:
        (w13, w2, w13_scale, w2_scale) in the shuffled kernel layout; the
        weight tensors are viewed back as float8_e4m3fn and the scale
        tensors keep their original shapes.
    """
    # Local import: flashinfer is an optional dependency, only needed on
    # this code path.
    from flashinfer import (
        reorder_rows_for_gated_act_gemm,
        shuffle_matrix_a,
        shuffle_matrix_sf_a,
    )

    epilogue_tile_m = 128
    num_experts = w13.shape[0]
    intermediate_size = w13.shape[1] // 2
    hidden_size = w13.shape[2]

    # Single pass per expert: optionally interleave gate/up rows, then
    # shuffle both the weight data and its scale factors.
    w13_parts: list[torch.Tensor] = []
    w2_parts: list[torch.Tensor] = []
    w13_sf_parts: list[torch.Tensor] = []
    w2_sf_parts: list[torch.Tensor] = []
    for expert in range(num_experts):
        w13_e = w13[expert]
        w13_sf_e = w13_scale[expert]
        if is_gated:
            # TRT-LLM's gated-activation GEMM wants gate/up rows interleaved.
            w13_e = reorder_rows_for_gated_act_gemm(
                w13_e.reshape(2 * intermediate_size, -1)
            )
            w13_sf_e = reorder_rows_for_gated_act_gemm(
                w13_sf_e.reshape(2 * intermediate_size, -1)
            )
        # Weight shuffles operate on raw bytes, hence the uint8 views.
        w13_parts.append(
            shuffle_matrix_a(w13_e.view(torch.uint8), epilogue_tile_m)
        )
        w2_parts.append(
            shuffle_matrix_a(w2[expert].view(torch.uint8), epilogue_tile_m)
        )
        w13_sf_parts.append(
            shuffle_matrix_sf_a(
                w13_sf_e.view(torch.uint8).reshape(2 * intermediate_size, -1),
                epilogue_tile_m,
            )
        )
        w2_sf_parts.append(
            shuffle_matrix_sf_a(
                w2_scale[expert].view(torch.uint8).reshape(hidden_size, -1),
                epilogue_tile_m,
            )
        )

    # Re-stack per-expert results; weights go back to fp8, scales back to
    # their incoming shapes.
    return (
        torch.stack(w13_parts).view(torch.float8_e4m3fn),
        torch.stack(w2_parts).view(torch.float8_e4m3fn),
        torch.stack(w13_sf_parts).reshape(w13_scale.shape),
        torch.stack(w2_sf_parts).reshape(w2_scale.shape),
    )
def prepare_fp8_moe_layer_for_fi(
layer: torch.nn.Module,
w13: torch.Tensor,
@@ -314,7 +389,7 @@ def prepare_fp8_moe_layer_for_fi(
w2_scale: torch.Tensor,
w2_input_scale: torch.Tensor | None,
is_trtllm: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Convert Fp8 MoE weights to flashinfer kernel format
@@ -329,10 +404,33 @@ def prepare_fp8_moe_layer_for_fi(
block_quant = (
hasattr(layer, "weight_block_size") and layer.weight_block_size is not None
)
is_mxfp8 = block_quant and w13_scale.dtype == torch.uint8
is_gated = layer.activation.is_gated
# MXFP8 TRT-LLM requires W31 swap + reorder + shuffle.
if is_mxfp8 and is_trtllm:
# FlashInfer TRT-LLM SwiGLU expects [up; gate] but vLLM stores
# [gate; up]. Swap both weights and scales before interleaving.
if layer.moe_config.is_act_and_mul:
w13 = swap_w13_to_w31(w13)
# Scales may be 2D [E, flat] from _quantize_mxfp8_moe_weight;
# reshape to 3D so swap_w13_to_w31 can flip the two halves,
# then flatten back.
if w13_scale.ndim == 2:
num_rows = w13.shape[1] # 2 * intermediate_size
w13_scale = w13_scale.reshape(w13_scale.shape[0], num_rows, -1)
w13_scale = swap_w13_to_w31(w13_scale)
w13_scale = w13_scale.reshape(w13_scale.shape[0], -1)
else:
w13_scale = swap_w13_to_w31(w13_scale)
w13, w2, w13_scale, w2_scale = _shuffle_mxfp8_moe_weights(
w13, w2, w13_scale, w2_scale, is_gated
)
return w13, w2, w13_scale, w2_scale
# Some FI MoE kernels require internal alignment of 16
# for the gate-up proj. Pad the weights to respect this.
is_gated = layer.activation.is_gated
if not block_quant:
min_alignment = 16 if is_gated else 128
w13, w2, new_intermediate = align_fp8_moe_weights_for_fi(
@@ -369,4 +467,4 @@ def prepare_fp8_moe_layer_for_fi(
w13_scale.clamp_(min=_FI_CUTLASS_MIN_BLOCK_SCALE)
w2_scale.clamp_(min=_FI_CUTLASS_MIN_BLOCK_SCALE)
return w13, w2, w13_scale
return w13, w2, w13_scale, w2_scale