[Perf] Eliminate padding and slicing op for GPT-OSS with Flashinfer MXFP4 MXFP8 MoE (#30647)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
This commit is contained in:
elvischenv
2026-03-18 23:01:26 +08:00
committed by GitHub
parent c373b5c00d
commit 296839a1b0
6 changed files with 40 additions and 3 deletions

View File

@@ -101,6 +101,11 @@ class FusedMoEMethodBase(QuantizeMethodBase):
return self.moe_kernel.prepare_finalize.topk_indices_dtype()
return None
@property
def skip_forward_padding(self) -> bool:
"""Whether to skip the padding in the forward before applying the moe method."""
return False
@property
def supports_eplb(self) -> bool:
return False

View File

@@ -415,7 +415,10 @@ class DefaultMoERunner(MoERunner):
# This is the dimension after transform (for routed expert output slicing)
transformed_hidden_dim = hidden_states.shape[-1]
if self.moe_config.hidden_dim != transformed_hidden_dim:
if (
not self.quant_method.skip_forward_padding
and self.moe_config.hidden_dim != transformed_hidden_dim
):
hidden_states = F.pad(
hidden_states,
(0, self.moe_config.hidden_dim - transformed_hidden_dim),

View File

@@ -294,6 +294,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
# Initialized in process_weights_after_loading for CUTLASS/SM90 backends
self.moe_kernel: mk.FusedMoEKernel | None = None
@property
def skip_forward_padding(self) -> bool:
# SM100_FI_MXFP4_MXFP8_TRTLLM supports padding with mxfp8 quant
# so can skip the padding in the forward before applying the moe method
return self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
def create_weights(
self,
layer: torch.nn.Module,
@@ -1130,9 +1136,17 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM:
from flashinfer import mxfp8_quantize
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
# x_quant is padded in hidden dimension with alignment=256
x_quant, x_scale = mxfp8_quantize(
x,
is_sf_swizzled_layout=False,
alignment=256,
)
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)
# output with original unpadded hidden size
output = torch.empty_like(x)
trtllm_gen_output = trtllm_fp4_block_scale_moe(
routing_logits=router_logits.to(torch.bfloat16),
routing_bias=None,
@@ -1161,6 +1175,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
routing_method_type=1 if layer.renormalize else 0,
do_finalize=True,
tune_max_num_tokens=max(self.max_capture_size, 1),
output=output,
)[0]
return trtllm_gen_output
elif self.mxfp4_backend == Mxfp4Backend.CK: