[Perf] Eliminate padding and slicing op for GPT-OSS with Flashinfer MXFP4 MXFP8 MoE (#30647)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
This commit is contained in:
elvischenv
2026-03-18 23:01:26 +08:00
committed by GitHub
parent c373b5c00d
commit 296839a1b0
6 changed files with 40 additions and 3 deletions

View File

@@ -82,6 +82,10 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
f"attention backend '{attn_backend.backend.name}'"
)
# TODO: remove this after finishing migration from envs to model kwargs
if model_name == "openai/gpt-oss-20b":
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
# Disable, compile cache to make sure custom passes run.
# Otherwise, we can't verify fusion happened through the logs.
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

View File

@@ -162,3 +162,12 @@ deepseek_v3_fp8 = ModelFusionInfo(
# async_tp=n_layers * 2,
),
)
gpt_oss_20b = ModelFusionInfo(
model_name="openai/gpt-oss-20b",
matches=lambda n_layers: Matches(
ar_rms_fusion=n_layers * 2 + 1,
sequence_parallel=n_layers * 2 + 1,
async_tp=n_layers * 2,
),
)

View File

@@ -20,6 +20,7 @@ from .models import (
FLASHINFER_MLA_ATTN,
TRITON_ATTN,
deepseek_v3_fp8,
gpt_oss_20b,
llama3_8b,
llama3_8b_fp4,
llama3_8b_fp8,
@@ -158,7 +159,7 @@ def test_tp2_ar_rms_fp4_fusions(
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, matches_fn, model_kwargs, hf_overrides",
[llama3_8b, qwen3_a3b],
[llama3_8b, qwen3_a3b, gpt_oss_20b],
)
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN])
@pytest.mark.parametrize("n_layers", [4])