[Perf] Eliminate padding and slicing op for GPT-OSS with Flashinfer MXFP4 MXFP8 MoE (#30647)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
This commit is contained in:
@@ -82,6 +82,10 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
f"attention backend '{attn_backend.backend.name}'"
|
||||
)
|
||||
|
||||
# TODO: remove this after finishing migration from envs to model kwargs
|
||||
if model_name == "openai/gpt-oss-20b":
|
||||
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
|
||||
|
||||
# Disable the compile cache to make sure custom passes run.
|
||||
# Otherwise, we can't verify fusion happened through the logs.
|
||||
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
|
||||
|
||||
@@ -162,3 +162,12 @@ deepseek_v3_fp8 = ModelFusionInfo(
|
||||
# async_tp=n_layers * 2,
|
||||
),
|
||||
)
|
||||
|
||||
gpt_oss_20b = ModelFusionInfo(
|
||||
model_name="openai/gpt-oss-20b",
|
||||
matches=lambda n_layers: Matches(
|
||||
ar_rms_fusion=n_layers * 2 + 1,
|
||||
sequence_parallel=n_layers * 2 + 1,
|
||||
async_tp=n_layers * 2,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -20,6 +20,7 @@ from .models import (
|
||||
FLASHINFER_MLA_ATTN,
|
||||
TRITON_ATTN,
|
||||
deepseek_v3_fp8,
|
||||
gpt_oss_20b,
|
||||
llama3_8b,
|
||||
llama3_8b_fp4,
|
||||
llama3_8b_fp8,
|
||||
@@ -158,7 +159,7 @@ def test_tp2_ar_rms_fp4_fusions(
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, matches_fn, model_kwargs, hf_overrides",
|
||||
[llama3_8b, qwen3_a3b],
|
||||
[llama3_8b, qwen3_a3b, gpt_oss_20b],
|
||||
)
|
||||
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN])
|
||||
@pytest.mark.parametrize("n_layers", [4])
|
||||
|
||||
@@ -101,6 +101,11 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
return self.moe_kernel.prepare_finalize.topk_indices_dtype()
|
||||
return None
|
||||
|
||||
@property
|
||||
def skip_forward_padding(self) -> bool:
|
||||
"""Whether to skip padding the hidden states in forward before applying the MoE method."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def supports_eplb(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -415,7 +415,10 @@ class DefaultMoERunner(MoERunner):
|
||||
|
||||
# This is the dimension after transform (for routed expert output slicing)
|
||||
transformed_hidden_dim = hidden_states.shape[-1]
|
||||
if self.moe_config.hidden_dim != transformed_hidden_dim:
|
||||
if (
|
||||
not self.quant_method.skip_forward_padding
|
||||
and self.moe_config.hidden_dim != transformed_hidden_dim
|
||||
):
|
||||
hidden_states = F.pad(
|
||||
hidden_states,
|
||||
(0, self.moe_config.hidden_dim - transformed_hidden_dim),
|
||||
|
||||
@@ -294,6 +294,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
# Initialized in process_weights_after_loading for CUTLASS/SM90 backends
|
||||
self.moe_kernel: mk.FusedMoEKernel | None = None
|
||||
|
||||
@property
|
||||
def skip_forward_padding(self) -> bool:
|
||||
# SM100_FI_MXFP4_MXFP8_TRTLLM supports padding with mxfp8 quant
|
||||
# so the padding in forward can be skipped before applying the MoE method
|
||||
return self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -1130,9 +1136,17 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM:
|
||||
from flashinfer import mxfp8_quantize
|
||||
|
||||
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
|
||||
# x_quant is padded in hidden dimension with alignment=256
|
||||
x_quant, x_scale = mxfp8_quantize(
|
||||
x,
|
||||
is_sf_swizzled_layout=False,
|
||||
alignment=256,
|
||||
)
|
||||
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)
|
||||
|
||||
# output with original unpadded hidden size
|
||||
output = torch.empty_like(x)
|
||||
|
||||
trtllm_gen_output = trtllm_fp4_block_scale_moe(
|
||||
routing_logits=router_logits.to(torch.bfloat16),
|
||||
routing_bias=None,
|
||||
@@ -1161,6 +1175,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
routing_method_type=1 if layer.renormalize else 0,
|
||||
do_finalize=True,
|
||||
tune_max_num_tokens=max(self.max_capture_size, 1),
|
||||
output=output,
|
||||
)[0]
|
||||
return trtllm_gen_output
|
||||
elif self.mxfp4_backend == Mxfp4Backend.CK:
|
||||
|
||||
Reference in New Issue
Block a user