Use aiter triton fused_add_rmsnorm_pad for gpt-oss (#30976)

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
Rohan Potdar
2026-01-28 14:47:47 -06:00
committed by GitHub
parent 3e440786af
commit 59bcc5b6f2
9 changed files with 327 additions and 11 deletions

View File

@@ -102,6 +102,18 @@ def enable_act_fusion(cfg: "VllmConfig") -> bool:
) or cfg.compilation_config.is_custom_op_enabled("quant_fp8")
def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
"""Enable if using AITER RMSNorm and AITER Triton GEMMs
and hidden size is 2880 i.e. gpt-oss; otherwise Inductor handles fusion."""
return (
envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_RMSNORM
and envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
and cfg.model_config.get_hidden_size() == 2880
)
OPTIMIZATION_LEVEL_00 = {
"compilation_config": {
"pass_config": {
@@ -112,6 +124,7 @@ OPTIMIZATION_LEVEL_00 = {
"fuse_attn_quant": False,
"enable_sp": False,
"fuse_gemm_comms": False,
"fuse_act_padding": False,
},
"cudagraph_mode": CUDAGraphMode.NONE,
"use_inductor_graph_partition": False,
@@ -127,6 +140,7 @@ OPTIMIZATION_LEVEL_01 = {
"fuse_attn_quant": False,
"enable_sp": False,
"fuse_gemm_comms": False,
"fuse_act_padding": enable_norm_pad_fusion,
},
"cudagraph_mode": CUDAGraphMode.PIECEWISE,
"use_inductor_graph_partition": False,
@@ -142,6 +156,7 @@ OPTIMIZATION_LEVEL_02 = {
"fuse_attn_quant": IS_QUANTIZED,
"enable_sp": IS_DENSE,
"fuse_gemm_comms": IS_DENSE,
"fuse_act_padding": enable_norm_pad_fusion,
},
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
"use_inductor_graph_partition": False,
@@ -157,6 +172,7 @@ OPTIMIZATION_LEVEL_03 = {
"fuse_attn_quant": IS_QUANTIZED,
"enable_sp": IS_DENSE,
"fuse_gemm_comms": IS_DENSE,
"fuse_act_padding": enable_norm_pad_fusion,
},
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
"use_inductor_graph_partition": False,