diff --git a/tests/compile/fusions_e2e/conftest.py b/tests/compile/fusions_e2e/conftest.py
index 873f92cfe..5716c95bb 100644
--- a/tests/compile/fusions_e2e/conftest.py
+++ b/tests/compile/fusions_e2e/conftest.py
@@ -82,6 +82,10 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
             f"attention backend '{attn_backend.backend.name}'"
         )
 
+    # TODO: remove this after finishing migration from envs to model kwargs
+    if model_name == "openai/gpt-oss-20b":
+        monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
+
     # Disable, compile cache to make sure custom passes run.
     # Otherwise, we can't verify fusion happened through the logs.
     monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
diff --git a/tests/compile/fusions_e2e/models.py b/tests/compile/fusions_e2e/models.py
index 9d6c20264..1a5f18cc0 100644
--- a/tests/compile/fusions_e2e/models.py
+++ b/tests/compile/fusions_e2e/models.py
@@ -162,3 +162,12 @@ deepseek_v3_fp8 = ModelFusionInfo(
         # async_tp=n_layers * 2,
     ),
 )
+
+gpt_oss_20b = ModelFusionInfo(
+    model_name="openai/gpt-oss-20b",
+    matches=lambda n_layers: Matches(
+        ar_rms_fusion=n_layers * 2 + 1,
+        sequence_parallel=n_layers * 2 + 1,
+        async_tp=n_layers * 2,
+    ),
+)
diff --git a/tests/compile/fusions_e2e/test_tp2_ar_rms.py b/tests/compile/fusions_e2e/test_tp2_ar_rms.py
index 8ffadbfaf..301409b2b 100644
--- a/tests/compile/fusions_e2e/test_tp2_ar_rms.py
+++ b/tests/compile/fusions_e2e/test_tp2_ar_rms.py
@@ -20,6 +20,7 @@ from .models import (
     FLASHINFER_MLA_ATTN,
     TRITON_ATTN,
     deepseek_v3_fp8,
+    gpt_oss_20b,
     llama3_8b,
     llama3_8b_fp4,
     llama3_8b_fp8,
@@ -158,7 +159,7 @@ def test_tp2_ar_rms_fp4_fusions(
 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize(
     "model_name, matches_fn, model_kwargs, hf_overrides",
-    [llama3_8b, qwen3_a3b],
+    [llama3_8b, qwen3_a3b, gpt_oss_20b],
 )
 @pytest.mark.parametrize("attn_backend", [TRITON_ATTN])
 @pytest.mark.parametrize("n_layers", [4])
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
index 88cd173fe..f6a303e79 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
@@ -101,6 +101,11 @@ class FusedMoEMethodBase(QuantizeMethodBase):
             return self.moe_kernel.prepare_finalize.topk_indices_dtype()
         return None
 
+    @property
+    def skip_forward_padding(self) -> bool:
+        """Whether to skip padding in the forward pass before applying the MoE method."""
+        return False
+
     @property
     def supports_eplb(self) -> bool:
         return False
diff --git a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
index b6313776e..12b560493 100644
--- a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
+++ b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
@@ -415,7 +415,10 @@ class DefaultMoERunner(MoERunner):
         # This is the dimension after transform (for routed expert output slicing)
         transformed_hidden_dim = hidden_states.shape[-1]
 
-        if self.moe_config.hidden_dim != transformed_hidden_dim:
+        if (
+            not self.quant_method.skip_forward_padding
+            and self.moe_config.hidden_dim != transformed_hidden_dim
+        ):
             hidden_states = F.pad(
                 hidden_states,
                 (0, self.moe_config.hidden_dim - transformed_hidden_dim),
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 1ad024a6f..f992d0f86 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -294,6 +294,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         # Initialized in process_weights_after_loading for CUTLASS/SM90 backends
         self.moe_kernel: mk.FusedMoEKernel | None = None
 
+    @property
+    def skip_forward_padding(self) -> bool:
+        # SM100_FI_MXFP4_MXFP8_TRTLLM handles padding as part of mxfp8 quantization,
+        # so we can skip padding in the forward pass before applying the MoE method.
+        return self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -1130,9 +1136,17 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM:
             from flashinfer import mxfp8_quantize
 
-            x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
+            # x_quant is padded in the hidden dimension with alignment=256
+            x_quant, x_scale = mxfp8_quantize(
+                x,
+                is_sf_swizzled_layout=False,
+                alignment=256,
+            )
             x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)
 
+            # output with the original unpadded hidden size
+            output = torch.empty_like(x)
+
             trtllm_gen_output = trtllm_fp4_block_scale_moe(
                 routing_logits=router_logits.to(torch.bfloat16),
                 routing_bias=None,
@@ -1161,6 +1175,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 routing_method_type=1 if layer.renormalize else 0,
                 do_finalize=True,
                 tune_max_num_tokens=max(self.max_capture_size, 1),
+                output=output,
             )[0]
             return trtllm_gen_output
         elif self.mxfp4_backend == Mxfp4Backend.CK: