diff --git a/tests/compile/test_fuse_act_padding.py b/tests/compile/test_fuse_act_padding.py new file mode 100644 index 000000000..d2670cd64 --- /dev/null +++ b/tests/compile/test_fuse_act_padding.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import pytest +import torch + +import vllm.config +from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops +from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.compilation.post_cleanup import PostCleanupPass +from vllm.config import ( + CompilationConfig, + CompilationMode, + ModelConfig, + PassConfig, + VllmConfig, +) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.utils import rocm_unquantized_gemm + +from .backend import TestBackend + + +class TestModel(torch.nn.Module): + def __init__( + self, + num_layers: int, + hidden_size: int, + num_local_experts: int, + x_pad_to_multiple: int, + ): + super().__init__() + self.num_layers = num_layers + self.hidden_size = hidden_size + self.x_pad_to_multiple = x_pad_to_multiple + self.pad_dim = x_pad_to_multiple - (hidden_size % x_pad_to_multiple) + + self.norm = [RMSNorm(hidden_size, eps=1e-5) for _ in range(num_layers)] + self.router = [ + torch.nn.Linear(hidden_size, num_local_experts) for _ in range(4) + ] + + def forward(self, x): + # avoid having graph input be an arg to a pattern directly + x = resid = torch.relu(x) + all_router_logits = [] + for layer in range(self.num_layers): + x = x[:, : self.hidden_size] + x, resid = self.norm[layer](x, resid) + router_logits = rocm_unquantized_gemm( + self, x, self.router[layer].weight, self.router[layer].bias + ) + x = torch.nn.functional.pad( + x, (0, self.pad_dim), mode="constant", value=0.0 + ) + all_router_logits.append(router_logits) + + return x, resid, *all_router_logits + + def ops_in_model_before(self): + return [ + rocm_aiter_ops.get_rmsnorm_fused_add_op(), + torch.ops.aten.constant_pad_nd, + ] + + def ops_in_model_after(self): + return [rocm_aiter_ops.get_triton_add_rmsnorm_pad_op()] + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("num_layers", [3]) +@pytest.mark.parametrize("hidden_size", [2880]) +@pytest.mark.parametrize("num_local_experts", [128]) +@pytest.mark.parametrize("x_pad_to_multiple", [256]) +@pytest.mark.skipif( + not is_aiter_found_and_supported(), + reason="Only test on ROCm with AITER installed and supported", +) +def test_fuse_act_padding( + dtype: torch.dtype, + num_layers: int, + hidden_size: int, + num_local_experts: int, + x_pad_to_multiple: int, + monkeypatch: pytest.MonkeyPatch, +): + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=dtype), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=["+rms_norm"], + pass_config=PassConfig(fuse_act_padding=True, eliminate_noops=True), + ), + ) + + with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m: + from vllm.compilation.rocm_aiter_fusion import ( + RocmAiterTritonAddRMSNormPadFusionPass, + ) + + torch.set_default_device("cuda") + torch.set_default_dtype(dtype) + torch.manual_seed(1) + + m.setenv("VLLM_ROCM_USE_AITER", "1") + rocm_aiter_ops.refresh_env_variables() + + fusion_pass = RocmAiterTritonAddRMSNormPadFusionPass(vllm_config) + passes = [ + NoOpEliminationPass(vllm_config), + fusion_pass, + PostCleanupPass(vllm_config), + ] + backend = TestBackend(*passes) + model = TestModel(num_layers, 
hidden_size, num_local_experts, x_pad_to_multiple) + + x = torch.rand(1, hidden_size) + torch._dynamo.mark_dynamic(x, 0) + + outputs_unfused = model(x) + + model_fused = torch.compile(model, backend=backend) + outputs_fused = model_fused(x) + + torch.testing.assert_close(outputs_unfused, outputs_fused) + + assert fusion_pass.matched_count == num_layers + + backend.check_before_ops(model.ops_in_model_before()) + backend.check_after_ops(model.ops_in_model_after()) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 26afb0203..e4a4fef23 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -410,7 +410,7 @@ def test_aiter_fusion_rmsnorm_quant( ) with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m: - from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormFusionPass + from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormQuantFusionPass m.setenv("VLLM_ROCM_USE_AITER", "1") @@ -420,7 +420,7 @@ def test_aiter_fusion_rmsnorm_quant( torch.set_default_dtype(dtype) torch.manual_seed(1) - fusion_pass = RocmAiterRMSNormFusionPass(vllm_config) + fusion_pass = RocmAiterRMSNormQuantFusionPass(vllm_config) model = TestModel( hidden_size=hidden_size, diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 419ac6c50..8dea38248 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -790,6 +790,41 @@ def _rocm_aiter_act_mul_and_fp8_group_quant_fake( return x_fp8, out_bs +def _rocm_aiter_triton_add_rmsnorm_pad_impl( + x: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, + residual: torch.Tensor, + x_pad_to_multiple: int, +) -> tuple[torch.Tensor, torch.Tensor]: + from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad + + return fused_add_rmsnorm_pad( + x, + weight, + variance_epsilon, + residual, + x_pad_to_multiple=x_pad_to_multiple, + ) + + +def _rocm_aiter_triton_add_rmsnorm_pad_fake( + x: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, + residual: torch.Tensor, + x_pad_to_multiple: int, +) -> tuple[torch.Tensor, torch.Tensor]: + M, N = x.shape + if x_pad_to_multiple > 0: + N_out = (N + x_pad_to_multiple - 1) // x_pad_to_multiple * x_pad_to_multiple + else: + N_out = N + out = torch.empty((M, N_out), dtype=x.dtype, device=x.device) + residual_out = torch.empty_like(residual) + return out, residual_out + + # Global flag to ensure ops are registered only once _OPS_REGISTERED = False @@ -1108,6 +1143,13 @@ class rocm_aiter_ops: fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake, ) + direct_register_custom_op( + op_name="rocm_aiter_triton_add_rmsnorm_pad", + op_func=_rocm_aiter_triton_add_rmsnorm_pad_impl, + fake_impl=_rocm_aiter_triton_add_rmsnorm_pad_fake, + dispatch_key=current_platform.dispatch_key, + ) + direct_register_custom_op( op_name="rocm_aiter_group_fp8_quant", op_func=_rocm_aiter_group_fp8_quant_impl, @@ -1175,6 +1217,10 @@ class rocm_aiter_ops: def get_act_mul_fused_fp8_group_quant_op() -> OpOverload: return torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default + @staticmethod + def get_triton_add_rmsnorm_pad_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_triton_add_rmsnorm_pad.default + @staticmethod def rms_norm( x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index a207edd93..e0565ccb2 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -18,8 +18,9 @@ from .vllm_inductor_pass import 
VllmInductorPass if rocm_aiter_ops.is_enabled(): from vllm.compilation.rocm_aiter_fusion import ( - RocmAiterRMSNormFusionPass, + RocmAiterRMSNormQuantFusionPass, RocmAiterSiluMulFp8GroupQuantFusionPass, + RocmAiterTritonAddRMSNormPadFusionPass, ) if current_platform.is_cuda_alike(): @@ -123,13 +124,16 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc] self.passes += [RMSNormQuantFusionPass(config)] if rocm_aiter_ops.is_enabled(): self.passes += [ - RocmAiterRMSNormFusionPass(config), + RocmAiterRMSNormQuantFusionPass(config), ] if self.pass_config.fuse_act_quant: self.passes += [ActivationQuantFusionPass(config)] if rocm_aiter_ops.is_enabled(): self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)] + if self.pass_config.fuse_act_padding and rocm_aiter_ops.is_enabled(): + self.passes += [RocmAiterTritonAddRMSNormPadFusionPass(config)] + if self.pass_config.fuse_attn_quant: self.passes += [AttnFusionPass(config)] diff --git a/vllm/compilation/rocm_aiter_fusion.py b/vllm/compilation/rocm_aiter_fusion.py index 7a300cf50..bfbb2b783 100644 --- a/vllm/compilation/rocm_aiter_fusion.py +++ b/vllm/compilation/rocm_aiter_fusion.py @@ -266,7 +266,7 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern): ) -class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass): +class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass): """ This pass fuses aiter rms_norm & vllm/aiter quant custom ops into a fused rms_norm_quant op. @@ -399,3 +399,106 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass): AiterSiluMulFp8GroupQuantPattern, ] return VllmInductorPass.hash_source(self, *fusion_patterns) + + +class AddAiterRMSNormPadPattern: + """ + This pattern replaces an aiter_rmsnorm_with_add & a pad op + with a custom triton_add_rmsnorm_pad op from AITER. 
+ """ + + AITER_TRITON_ADD_RMSNORM_PAD_OP = rocm_aiter_ops.get_triton_add_rmsnorm_pad_op() + + def __init__( + self, + epsilon: float, + hidden_size: int, + x_pad_to_multiple: int, + ): + self.epsilon = epsilon + self.hidden_size = hidden_size + self.x_pad_to_multiple = x_pad_to_multiple + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True) + + def get_inputs(self) -> list[torch.Tensor]: + input, weight, residual = self.rmsnorm_matcher.inputs() + router_weight = torch.empty([8, 16], dtype=weight.dtype, device=weight.device) + router_bias = torch.empty([8], dtype=weight.dtype, device=weight.device) + return [input, weight, residual, router_weight, router_bias] + + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + router_weight: torch.Tensor, + router_bias: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + pad_size = self.x_pad_to_multiple - ( + self.hidden_size % self.x_pad_to_multiple + ) + result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual) + router_logits = torch.ops.vllm.rocm_unquantized_gemm( + result_rms, router_weight, router_bias + ) + result = torch.nn.functional.pad( + result_rms, (0, pad_size), mode="constant", value=0.0 + ) + return result, residual_out, router_logits + + def replacement( + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + router_weight: torch.Tensor, + router_bias: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + at = self.AITER_TRITON_ADD_RMSNORM_PAD_OP( + x=input, + weight=weight, + variance_epsilon=self.epsilon, + residual=residual, + x_pad_to_multiple=self.x_pad_to_multiple, + ) + result_padded = at[0] + router_logits = torch.ops.vllm.rocm_unquantized_gemm( + result_padded[:, : self.hidden_size], router_weight, router_bias + ) + residual_out = at[1] + return result_padded, residual_out, router_logits + + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) + + +class RocmAiterTritonAddRMSNormPadFusionPass(VllmPatternMatcherPass): + """ + This pass replaces an AITER CK RMSNorm + residual add and a pad op + with an triton_add_rmsnorm_pad op from AITER. 
+ """ + + def __init__(self, config: VllmConfig): + super().__init__(config) + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="rocm_aiter_triton_add_rmsnorm_pad_fusion_pass" + ) + + # gpt-oss has hidden size 2880 + # padded to a multiple of 128 on gfx942 and 256 on gfx950 respectively + hidden_size = 2880 + for epsilon in [1e-5, 1e-6]: + for x_pad_to_multiple in [128, 256]: + AddAiterRMSNormPadPattern( + epsilon, hidden_size, x_pad_to_multiple + ).register(self.patterns) + + self.dump_patterns(config, self.patterns) + + @VllmInductorPass.time_and_log + def __call__(self, graph: torch.fx.Graph) -> None: + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) + + def uuid(self) -> str: + return VllmInductorPass.hash_source(self, AddAiterRMSNormPadPattern) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 1a3c452fc..fa711ac44 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -126,6 +126,10 @@ class PassConfig: fuse_allreduce_rms: bool = Field(default=None) """Enable flashinfer allreduce fusion.""" + # ROCm/AITER specific fusions + fuse_act_padding: bool = Field(default=None) + """Fuse the custom RMSNorm + padding ops.""" + fi_allreduce_fusion_max_size_mb: float | None = None """The threshold of the communicated tensor sizes under which vllm should use flashinfer fused allreduce. Specified as a @@ -194,6 +198,7 @@ class PassConfig: "enable_sp", "fuse_gemm_comms", "fuse_allreduce_rms", + "fuse_act_padding", mode="wrap", ) @classmethod @@ -222,12 +227,23 @@ class PassConfig: "Fusion enabled but reshape elimination disabled. " "Allreduce + rms norm + quant (fp8) fusion might not work" ) + if self.fuse_act_padding: + logger.warning_once( + "Fusion enabled but reshape elimination disabled. " + "RMSNorm + padding fusion might not work" + ) if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda_alike(): logger.warning_once( "QK Norm + RoPE fusion enabled but the current platform is not " "CUDA or ROCm. The fusion will be disabled." ) self.enable_qk_norm_rope_fusion = False + if self.fuse_act_padding and not current_platform.is_rocm(): + logger.warning_once( + "Padding fusion enabled but the current platform is not ROCm. " + "The fusion will be disabled." + ) + self.fuse_act_padding = False class DynamicShapesType(str, enum.Enum): diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 1f8f5e5db..3e5a4b8f5 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -102,6 +102,18 @@ def enable_act_fusion(cfg: "VllmConfig") -> bool: ) or cfg.compilation_config.is_custom_op_enabled("quant_fp8") +def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool: + """Enable if using AITER RMSNorm and AITER Triton GEMMs + and hidden size is 2880 i.e. 
gpt-oss; otherwise Inductor handles fusion.""" + + return ( + envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_AITER_RMSNORM + and envs.VLLM_ROCM_USE_AITER_TRITON_GEMM + and cfg.model_config.get_hidden_size() == 2880 + ) + + OPTIMIZATION_LEVEL_00 = { "compilation_config": { "pass_config": { @@ -112,6 +124,7 @@ OPTIMIZATION_LEVEL_00 = { "fuse_attn_quant": False, "enable_sp": False, "fuse_gemm_comms": False, + "fuse_act_padding": False, }, "cudagraph_mode": CUDAGraphMode.NONE, "use_inductor_graph_partition": False, @@ -127,6 +140,7 @@ OPTIMIZATION_LEVEL_01 = { "fuse_attn_quant": False, "enable_sp": False, "fuse_gemm_comms": False, + "fuse_act_padding": enable_norm_pad_fusion, }, "cudagraph_mode": CUDAGraphMode.PIECEWISE, "use_inductor_graph_partition": False, @@ -142,6 +156,7 @@ OPTIMIZATION_LEVEL_02 = { "fuse_attn_quant": IS_QUANTIZED, "enable_sp": IS_DENSE, "fuse_gemm_comms": IS_DENSE, + "fuse_act_padding": enable_norm_pad_fusion, }, "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE, "use_inductor_graph_partition": False, @@ -157,6 +172,7 @@ OPTIMIZATION_LEVEL_03 = { "fuse_attn_quant": IS_QUANTIZED, "enable_sp": IS_DENSE, "fuse_gemm_comms": IS_DENSE, + "fuse_act_padding": enable_norm_pad_fusion, }, "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE, "use_inductor_graph_partition": False, diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 364204bec..2ec364213 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -137,6 +137,11 @@ def rocm_unquantized_gemm_impl( import math + if use_aiter_triton_gemm(n, m, k, x.dtype): + from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 + + return gemm_a16w16(x, weight, bias) + use_skinny_reduce_counting = ( envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx950() @@ -155,11 +160,6 @@ def rocm_unquantized_gemm_impl( out = ops.wvSplitKrc(weight, x_view, cu_count, bias) return out.reshape(*x.shape[:-1], weight.shape[0]) - if use_aiter_triton_gemm(n, m, k, x.dtype): - from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 - - return gemm_a16w16(x, weight, bias) - use_skinny = ( envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index b273880ce..f62771c36 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -187,7 +187,7 @@ class MLPBlock(torch.nn.Module): ) else: g = self.router(x) - x = self.experts(hidden_states=x, router_logits=g) + x = self.experts(hidden_states=x, router_logits=g)[:, : self.hidden_size] if self.is_sequence_parallel: x = tensor_model_parallel_all_gather(x.contiguous(), 0)
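
A minimal eager-mode sketch of the semantics the new rocm_aiter_triton_add_rmsnorm_pad custom op is assumed to provide — fused residual add, RMSNorm over the hidden dimension, then zero-padding of the normalized activation to a multiple of x_pad_to_multiple — following the pattern and the fake implementation in this diff. The helper name add_rmsnorm_pad_reference is illustrative only, and the AITER Triton kernel's accumulation/dtype handling may differ from this sketch.

import torch


def add_rmsnorm_pad_reference(
    x: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    residual: torch.Tensor,
    x_pad_to_multiple: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Fused residual add: the updated residual is returned as the second output,
    # mirroring vLLM's fused-add RMSNorm convention.
    residual_out = x + residual
    # RMSNorm over the last dimension (computed in fp32 here for clarity; the
    # actual kernel's accumulation dtype may differ).
    variance = residual_out.float().pow(2).mean(dim=-1, keepdim=True)
    normed = residual_out.float() * torch.rsqrt(variance + variance_epsilon)
    normed = (normed * weight.float()).to(x.dtype)
    # Zero-pad the hidden dimension up to the next multiple of x_pad_to_multiple,
    # matching the output shape of _rocm_aiter_triton_add_rmsnorm_pad_fake.
    if x_pad_to_multiple > 0:
        pad = (-normed.shape[-1]) % x_pad_to_multiple
        normed = torch.nn.functional.pad(normed, (0, pad), value=0.0)
    return normed, residual_out

The fusion pass itself stays opt-in: it only runs when PassConfig(fuse_act_padding=True) is set and rocm_aiter_ops.is_enabled() (VLLM_ROCM_USE_AITER=1), as exercised by tests/compile/test_fuse_act_padding.py above.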