Use aiter triton fused_add_rmsnorm_pad for gpt-oss (#30976)

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
Rohan Potdar
2026-01-28 14:47:47 -06:00
committed by GitHub
parent 3e440786af
commit 59bcc5b6f2
9 changed files with 327 additions and 11 deletions

View File

@@ -126,6 +126,10 @@ class PassConfig:
fuse_allreduce_rms: bool = Field(default=None)
"""Enable flashinfer allreduce fusion."""
# ROCm/AITER specific fusions
fuse_act_padding: bool = Field(default=None)
"""Fuse the custom RMSNorm + padding ops."""
fi_allreduce_fusion_max_size_mb: float | None = None
"""The threshold of the communicated tensor sizes under which
vllm should use flashinfer fused allreduce. Specified as a
@@ -194,6 +198,7 @@ class PassConfig:
"enable_sp",
"fuse_gemm_comms",
"fuse_allreduce_rms",
"fuse_act_padding",
mode="wrap",
)
@classmethod
@@ -222,12 +227,23 @@ class PassConfig:
"Fusion enabled but reshape elimination disabled. "
"Allreduce + rms norm + quant (fp8) fusion might not work"
)
if self.fuse_act_padding:
logger.warning_once(
"Fusion enabled but reshape elimination disabled. "
"RMSNorm + padding fusion might not work"
)
if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda_alike():
logger.warning_once(
"QK Norm + RoPE fusion enabled but the current platform is not "
"CUDA or ROCm. The fusion will be disabled."
)
self.enable_qk_norm_rope_fusion = False
if self.fuse_act_padding and not current_platform.is_rocm():
logger.warning_once(
"Padding fusion enabled but the current platform is not ROCm. "
"The fusion will be disabled."
)
self.fuse_act_padding = False
class DynamicShapesType(str, enum.Enum):