[Perf] Eliminate padding and slicing op for GPT-OSS with Flashinfer MXFP4 MXFP8 MoE (#30647)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
This commit is contained in:
@@ -101,6 +101,11 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
return self.moe_kernel.prepare_finalize.topk_indices_dtype()
|
||||
return None
|
||||
|
||||
@property
|
||||
def skip_forward_padding(self) -> bool:
|
||||
"""Whether to skip the padding in the forward before applying the moe method."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def supports_eplb(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -415,7 +415,10 @@ class DefaultMoERunner(MoERunner):
|
||||
|
||||
# This is the dimension after transform (for routed expert output slicing)
|
||||
transformed_hidden_dim = hidden_states.shape[-1]
|
||||
if self.moe_config.hidden_dim != transformed_hidden_dim:
|
||||
if (
|
||||
not self.quant_method.skip_forward_padding
|
||||
and self.moe_config.hidden_dim != transformed_hidden_dim
|
||||
):
|
||||
hidden_states = F.pad(
|
||||
hidden_states,
|
||||
(0, self.moe_config.hidden_dim - transformed_hidden_dim),
|
||||
|
||||
@@ -294,6 +294,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
# Initialized in process_weights_after_loading for CUTLASS/SM90 backends
|
||||
self.moe_kernel: mk.FusedMoEKernel | None = None
|
||||
|
||||
@property
|
||||
def skip_forward_padding(self) -> bool:
|
||||
# SM100_FI_MXFP4_MXFP8_TRTLLM supports padding with mxfp8 quant
|
||||
# so can skip the padding in the forward before applying the moe method
|
||||
return self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -1130,9 +1136,17 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM:
|
||||
from flashinfer import mxfp8_quantize
|
||||
|
||||
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
|
||||
# x_quant is padded in hidden dimension with alignment=256
|
||||
x_quant, x_scale = mxfp8_quantize(
|
||||
x,
|
||||
is_sf_swizzled_layout=False,
|
||||
alignment=256,
|
||||
)
|
||||
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)
|
||||
|
||||
# output with original unpadded hidden size
|
||||
output = torch.empty_like(x)
|
||||
|
||||
trtllm_gen_output = trtllm_fp4_block_scale_moe(
|
||||
routing_logits=router_logits.to(torch.bfloat16),
|
||||
routing_bias=None,
|
||||
@@ -1161,6 +1175,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
routing_method_type=1 if layer.renormalize else 0,
|
||||
do_finalize=True,
|
||||
tune_max_num_tokens=max(self.max_capture_size, 1),
|
||||
output=output,
|
||||
)[0]
|
||||
return trtllm_gen_output
|
||||
elif self.mxfp4_backend == Mxfp4Backend.CK:
|
||||
|
||||
Reference in New Issue
Block a user