Use aiter triton fused_add_rmsnorm_pad for gpt-oss (#30976)

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
Rohan Potdar
2026-01-28 14:47:47 -06:00
committed by GitHub
parent 3e440786af
commit 59bcc5b6f2
9 changed files with 327 additions and 11 deletions

View File

@@ -18,8 +18,9 @@ from .vllm_inductor_pass import VllmInductorPass
if rocm_aiter_ops.is_enabled():
from vllm.compilation.rocm_aiter_fusion import (
RocmAiterRMSNormFusionPass,
RocmAiterRMSNormQuantFusionPass,
RocmAiterSiluMulFp8GroupQuantFusionPass,
RocmAiterTritonAddRMSNormPadFusionPass,
)
if current_platform.is_cuda_alike():
@@ -123,13 +124,16 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
self.passes += [RMSNormQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
self.passes += [
RocmAiterRMSNormFusionPass(config),
RocmAiterRMSNormQuantFusionPass(config),
]
if self.pass_config.fuse_act_quant:
self.passes += [ActivationQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
if self.pass_config.fuse_act_padding and rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterTritonAddRMSNormPadFusionPass(config)]
if self.pass_config.fuse_attn_quant:
self.passes += [AttnFusionPass(config)]

View File

@@ -266,7 +266,7 @@ class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
)
class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass):
class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses aiter rms_norm & vllm/aiter quant custom ops
into a fused rms_norm_quant op.
@@ -399,3 +399,106 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
AiterSiluMulFp8GroupQuantPattern,
]
return VllmInductorPass.hash_source(self, *fusion_patterns)
class AddAiterRMSNormPadPattern:
"""
This pattern replaces an aiter_rmsnorm_with_add & a pad op
with a custom triton_add_rmsnorm_pad op from AITER.
"""
AITER_TRITON_ADD_RMSNORM_PAD_OP = rocm_aiter_ops.get_triton_add_rmsnorm_pad_op()
def __init__(
self,
epsilon: float,
hidden_size: int,
x_pad_to_multiple: int,
):
self.epsilon = epsilon
self.hidden_size = hidden_size
self.x_pad_to_multiple = x_pad_to_multiple
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True)
def get_inputs(self) -> list[torch.Tensor]:
input, weight, residual = self.rmsnorm_matcher.inputs()
router_weight = torch.empty([8, 16], dtype=weight.dtype, device=weight.device)
router_bias = torch.empty([8], dtype=weight.dtype, device=weight.device)
return [input, weight, residual, router_weight, router_bias]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
router_weight: torch.Tensor,
router_bias: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pad_size = self.x_pad_to_multiple - (
self.hidden_size % self.x_pad_to_multiple
)
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
router_logits = torch.ops.vllm.rocm_unquantized_gemm(
result_rms, router_weight, router_bias
)
result = torch.nn.functional.pad(
result_rms, (0, pad_size), mode="constant", value=0.0
)
return result, residual_out, router_logits
def replacement(
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
router_weight: torch.Tensor,
router_bias: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
at = self.AITER_TRITON_ADD_RMSNORM_PAD_OP(
x=input,
weight=weight,
variance_epsilon=self.epsilon,
residual=residual,
x_pad_to_multiple=self.x_pad_to_multiple,
)
result_padded = at[0]
router_logits = torch.ops.vllm.rocm_unquantized_gemm(
result_padded[:, : self.hidden_size], router_weight, router_bias
)
residual_out = at[1]
return result_padded, residual_out, router_logits
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
class RocmAiterTritonAddRMSNormPadFusionPass(VllmPatternMatcherPass):
"""
This pass replaces an AITER CK RMSNorm + residual add and a pad op
with an triton_add_rmsnorm_pad op from AITER.
"""
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rocm_aiter_triton_add_rmsnorm_pad_fusion_pass"
)
# gpt-oss has hidden size 2880
# padded to a multiple of 128 on gfx942 and 256 on gfx950 respectively
hidden_size = 2880
for epsilon in [1e-5, 1e-6]:
for x_pad_to_multiple in [128, 256]:
AddAiterRMSNormPadPattern(
epsilon, hidden_size, x_pad_to_multiple
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self) -> str:
return VllmInductorPass.hash_source(self, AddAiterRMSNormPadPattern)