Use aiter triton fused_add_rmsnorm_pad for gpt-oss (#30976)

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
Rohan Potdar
2026-01-28 14:47:47 -06:00
committed by GitHub
parent 3e440786af
commit 59bcc5b6f2
9 changed files with 327 additions and 11 deletions

View File

@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
import vllm.config
from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (
CompilationConfig,
CompilationMode,
ModelConfig,
PassConfig,
VllmConfig,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.utils import rocm_unquantized_gemm
from .backend import TestBackend
class TestModel(torch.nn.Module):
def __init__(
self,
num_layers: int,
hidden_size: int,
num_local_experts: int,
x_pad_to_multiple: int,
):
super().__init__()
self.num_layers = num_layers
self.hidden_size = hidden_size
self.x_pad_to_multiple = x_pad_to_multiple
self.pad_dim = x_pad_to_multiple - (hidden_size % x_pad_to_multiple)
self.norm = [RMSNorm(hidden_size, eps=1e-5) for _ in range(num_layers)]
self.router = [
torch.nn.Linear(hidden_size, num_local_experts) for _ in range(4)
]
def forward(self, x):
# avoid having graph input be an arg to a pattern directly
x = resid = torch.relu(x)
all_router_logits = []
for layer in range(self.num_layers):
x = x[:, : self.hidden_size]
x, resid = self.norm[layer](x, resid)
router_logits = rocm_unquantized_gemm(
self, x, self.router[layer].weight, self.router[layer].bias
)
x = torch.nn.functional.pad(
x, (0, self.pad_dim), mode="constant", value=0.0
)
all_router_logits.append(router_logits)
return x, resid, *all_router_logits
def ops_in_model_before(self):
return [
rocm_aiter_ops.get_rmsnorm_fused_add_op(),
torch.ops.aten.constant_pad_nd,
]
def ops_in_model_after(self):
return [rocm_aiter_ops.get_triton_add_rmsnorm_pad_op()]
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_layers", [3])
@pytest.mark.parametrize("hidden_size", [2880])
@pytest.mark.parametrize("num_local_experts", [128])
@pytest.mark.parametrize("x_pad_to_multiple", [256])
@pytest.mark.skipif(
not is_aiter_found_and_supported(),
reason="Only test on ROCm with AITER installed and supported",
)
def test_fuse_act_padding(
dtype: torch.dtype,
num_layers: int,
hidden_size: int,
num_local_experts: int,
x_pad_to_multiple: int,
monkeypatch: pytest.MonkeyPatch,
):
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=["+rms_norm"],
pass_config=PassConfig(fuse_act_padding=True, eliminate_noops=True),
),
)
with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterTritonAddRMSNormPadFusionPass,
)
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
m.setenv("VLLM_ROCM_USE_AITER", "1")
rocm_aiter_ops.refresh_env_variables()
fusion_pass = RocmAiterTritonAddRMSNormPadFusionPass(vllm_config)
passes = [
NoOpEliminationPass(vllm_config),
fusion_pass,
PostCleanupPass(vllm_config),
]
backend = TestBackend(*passes)
model = TestModel(num_layers, hidden_size, num_local_experts, x_pad_to_multiple)
x = torch.rand(1, hidden_size)
torch._dynamo.mark_dynamic(x, 0)
outputs_unfused = model(x)
model_fused = torch.compile(model, backend=backend)
outputs_fused = model_fused(x)
torch.testing.assert_close(outputs_unfused, outputs_fused)
assert fusion_pass.matched_count == num_layers
backend.check_before_ops(model.ops_in_model_before())
backend.check_after_ops(model.ops_in_model_after())

View File

@@ -410,7 +410,7 @@ def test_aiter_fusion_rmsnorm_quant(
)
with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormFusionPass
from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormQuantFusionPass
m.setenv("VLLM_ROCM_USE_AITER", "1")
@@ -420,7 +420,7 @@ def test_aiter_fusion_rmsnorm_quant(
torch.set_default_dtype(dtype)
torch.manual_seed(1)
fusion_pass = RocmAiterRMSNormFusionPass(vllm_config)
fusion_pass = RocmAiterRMSNormQuantFusionPass(vllm_config)
model = TestModel(
hidden_size=hidden_size,

View File

@@ -790,6 +790,41 @@ def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
return x_fp8, out_bs
def _rocm_aiter_triton_add_rmsnorm_pad_impl(
x: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
residual: torch.Tensor,
x_pad_to_multiple: int,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad
return fused_add_rmsnorm_pad(
x,
weight,
variance_epsilon,
residual,
x_pad_to_multiple=x_pad_to_multiple,
)
def _rocm_aiter_triton_add_rmsnorm_pad_fake(
x: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
residual: torch.Tensor,
x_pad_to_multiple: int,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
if x_pad_to_multiple > 0:
N_out = (N + x_pad_to_multiple - 1) // x_pad_to_multiple * x_pad_to_multiple
else:
N_out = N
out = torch.empty((M, N_out), dtype=x.dtype, device=x.device)
residual_out = torch.empty_like(residual)
return out, residual_out
# Global flag to ensure ops are registered only once
_OPS_REGISTERED = False
@@ -1108,6 +1143,13 @@ class rocm_aiter_ops:
fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_triton_add_rmsnorm_pad",
op_func=_rocm_aiter_triton_add_rmsnorm_pad_impl,
fake_impl=_rocm_aiter_triton_add_rmsnorm_pad_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_group_fp8_quant",
op_func=_rocm_aiter_group_fp8_quant_impl,
@@ -1175,6 +1217,10 @@ class rocm_aiter_ops:
def get_act_mul_fused_fp8_group_quant_op() -> OpOverload:
return torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
@staticmethod
def get_triton_add_rmsnorm_pad_op() -> OpOverload:
return torch.ops.vllm.rocm_aiter_triton_add_rmsnorm_pad.default
@staticmethod
def rms_norm(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float

View File

@@ -18,8 +18,9 @@ from .vllm_inductor_pass import VllmInductorPass
if rocm_aiter_ops.is_enabled():
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterRMSNormFusionPass,
RocmAiterRMSNormQuantFusionPass,
RocmAiterSiluMulFp8GroupQuantFusionPass,
RocmAiterTritonAddRMSNormPadFusionPass,
)
if current_platform.is_cuda_alike():
@@ -123,13 +124,16 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
self.passes += [RMSNormQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
self.passes += [
RocmAiterRMSNormFusionPass(config),
RocmAiterRMSNormQuantFusionPass(config),
]
if self.pass_config.fuse_act_quant:
self.passes += [ActivationQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
if self.pass_config.fuse_act_padding and rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterTritonAddRMSNormPadFusionPass(config)]
if self.pass_config.fuse_attn_quant:
self.passes += [AttnFusionPass(config)]

View File

@@ -266,7 +266,7 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
)
class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass):
class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses aiter rms_norm & vllm/aiter quant custom ops
into a fused rms_norm_quant op.
@@ -399,3 +399,106 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
AiterSiluMulFp8GroupQuantPattern,
]
return VllmInductorPass.hash_source(self, *fusion_patterns)
class AddAiterRMSNormPadPattern:
"""
This pattern replaces an aiter_rmsnorm_with_add & a pad op
with a custom triton_add_rmsnorm_pad op from AITER.
"""
AITER_TRITON_ADD_RMSNORM_PAD_OP = rocm_aiter_ops.get_triton_add_rmsnorm_pad_op()
def __init__(
self,
epsilon: float,
hidden_size: int,
x_pad_to_multiple: int,
):
self.epsilon = epsilon
self.hidden_size = hidden_size
self.x_pad_to_multiple = x_pad_to_multiple
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True)
def get_inputs(self) -> list[torch.Tensor]:
input, weight, residual = self.rmsnorm_matcher.inputs()
router_weight = torch.empty([8, 16], dtype=weight.dtype, device=weight.device)
router_bias = torch.empty([8], dtype=weight.dtype, device=weight.device)
return [input, weight, residual, router_weight, router_bias]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
router_weight: torch.Tensor,
router_bias: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pad_size = self.x_pad_to_multiple - (
self.hidden_size % self.x_pad_to_multiple
)
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
router_logits = torch.ops.vllm.rocm_unquantized_gemm(
result_rms, router_weight, router_bias
)
result = torch.nn.functional.pad(
result_rms, (0, pad_size), mode="constant", value=0.0
)
return result, residual_out, router_logits
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
router_weight: torch.Tensor,
router_bias: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
at = self.AITER_TRITON_ADD_RMSNORM_PAD_OP(
x=input,
weight=weight,
variance_epsilon=self.epsilon,
residual=residual,
x_pad_to_multiple=self.x_pad_to_multiple,
)
result_padded = at[0]
router_logits = torch.ops.vllm.rocm_unquantized_gemm(
result_padded[:, : self.hidden_size], router_weight, router_bias
)
residual_out = at[1]
return result_padded, residual_out, router_logits
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class RocmAiterTritonAddRMSNormPadFusionPass(VllmPatternMatcherPass):
"""
This pass replaces an AITER CK RMSNorm + residual add and a pad op
with an triton_add_rmsnorm_pad op from AITER.
"""
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rocm_aiter_triton_add_rmsnorm_pad_fusion_pass"
)
# gpt-oss has hidden size 2880
# padded to a multiple of 128 on gfx942 and 256 on gfx950 respectively
hidden_size = 2880
for epsilon in [1e-5, 1e-6]:
for x_pad_to_multiple in [128, 256]:
AddAiterRMSNormPadPattern(
epsilon, hidden_size, x_pad_to_multiple
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> str:
return VllmInductorPass.hash_source(self, AddAiterRMSNormPadPattern)

View File

@@ -126,6 +126,10 @@ class PassConfig:
fuse_allreduce_rms: bool = Field(default=None)
"""Enable flashinfer allreduce fusion."""
# ROCm/AITER specific fusions
fuse_act_padding: bool = Field(default=None)
"""Fuse the custom RMSNorm + padding ops."""
fi_allreduce_fusion_max_size_mb: float | None = None
"""The threshold of the communicated tensor sizes under which
vllm should use flashinfer fused allreduce. Specified as a
@@ -194,6 +198,7 @@ class PassConfig:
"enable_sp",
"fuse_gemm_comms",
"fuse_allreduce_rms",
"fuse_act_padding",
mode="wrap",
)
@classmethod
@@ -222,12 +227,23 @@ class PassConfig:
"Fusion enabled but reshape elimination disabled. "
"Allreduce + rms norm + quant (fp8) fusion might not work"
)
if self.fuse_act_padding:
logger.warning_once(
"Fusion enabled but reshape elimination disabled. "
"RMSNorm + padding fusion might not work"
)
if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda_alike():
logger.warning_once(
"QK Norm + RoPE fusion enabled but the current platform is not "
"CUDA or ROCm. The fusion will be disabled."
)
self.enable_qk_norm_rope_fusion = False
if self.fuse_act_padding and not current_platform.is_rocm():
logger.warning_once(
"Padding fusion enabled but the current platform is not ROCm. "
"The fusion will be disabled."
)
self.fuse_act_padding = False
class DynamicShapesType(str, enum.Enum):

View File

@@ -102,6 +102,18 @@ def enable_act_fusion(cfg: "VllmConfig") -> bool:
) or cfg.compilation_config.is_custom_op_enabled("quant_fp8")
def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
"""Enable if using AITER RMSNorm and AITER Triton GEMMs
and hidden size is 2880 i.e. gpt-oss; otherwise Inductor handles fusion."""
return (
envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_RMSNORM
and envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
and cfg.model_config.get_hidden_size() == 2880
)
OPTIMIZATION_LEVEL_00 = {
"compilation_config": {
"pass_config": {
@@ -112,6 +124,7 @@ OPTIMIZATION_LEVEL_00 = {
"fuse_attn_quant": False,
"enable_sp": False,
"fuse_gemm_comms": False,
"fuse_act_padding": False,
},
"cudagraph_mode": CUDAGraphMode.NONE,
"use_inductor_graph_partition": False,
@@ -127,6 +140,7 @@ OPTIMIZATION_LEVEL_01 = {
"fuse_attn_quant": False,
"enable_sp": False,
"fuse_gemm_comms": False,
"fuse_act_padding": enable_norm_pad_fusion,
},
"cudagraph_mode": CUDAGraphMode.PIECEWISE,
"use_inductor_graph_partition": False,
@@ -142,6 +156,7 @@ OPTIMIZATION_LEVEL_02 = {
"fuse_attn_quant": IS_QUANTIZED,
"enable_sp": IS_DENSE,
"fuse_gemm_comms": IS_DENSE,
"fuse_act_padding": enable_norm_pad_fusion,
},
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
"use_inductor_graph_partition": False,
@@ -157,6 +172,7 @@ OPTIMIZATION_LEVEL_03 = {
"fuse_attn_quant": IS_QUANTIZED,
"enable_sp": IS_DENSE,
"fuse_gemm_comms": IS_DENSE,
"fuse_act_padding": enable_norm_pad_fusion,
},
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
"use_inductor_graph_partition": False,

View File

@@ -137,6 +137,11 @@ def rocm_unquantized_gemm_impl(
import math
if use_aiter_triton_gemm(n, m, k, x.dtype):
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
return gemm_a16w16(x, weight, bias)
use_skinny_reduce_counting = (
envs.VLLM_ROCM_USE_SKINNY_GEMM
and on_gfx950()
@@ -155,11 +160,6 @@ def rocm_unquantized_gemm_impl(
out = ops.wvSplitKrc(weight, x_view, cu_count, bias)
return out.reshape(*x.shape[:-1], weight.shape[0])
if use_aiter_triton_gemm(n, m, k, x.dtype):
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
return gemm_a16w16(x, weight, bias)
use_skinny = (
envs.VLLM_ROCM_USE_SKINNY_GEMM
and on_gfx9()

View File

@@ -187,7 +187,7 @@ class MLPBlock(torch.nn.Module):
)
else:
g = self.router(x)
x = self.experts(hidden_states=x, router_logits=g)
x = self.experts(hidden_states=x, router_logits=g)[:, : self.hidden_size]
if self.is_sequence_parallel:
x = tensor_model_parallel_all_gather(x.contiguous(), 0)