[ROCm]: Enable customop and rope+kvcache fusion for AITER RoPE (#35180)
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
@@ -126,14 +126,27 @@ def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool:
|
||||
)
|
||||
|
||||
|
||||
def enable_rope_kvcache_fusion(cfg: "VllmConfig") -> bool:
    """Predicate used for the ``fuse_rope_kvcache`` setting.

    The result is truthy only when all of the following hold, checked in
    this order and stopping at the first falsy one:

    1. ``rocm_aiter_ops.is_enabled()`` is truthy.
    2. The ``rotary_embedding`` custom op is enabled in
       ``cfg.compilation_config``.
    3. ``cfg.compilation_config.use_inductor_graph_partition`` is truthy.

    Args:
        cfg: The vLLM configuration whose ``compilation_config`` is inspected.

    Returns:
        The first falsy condition value, or the value of
        ``use_inductor_graph_partition`` when the first two conditions hold.
    """
    # Imported at call time rather than at module import time.
    from vllm._aiter_ops import rocm_aiter_ops

    aiter_enabled = rocm_aiter_ops.is_enabled()
    if not aiter_enabled:
        return aiter_enabled

    rope_custom_op = cfg.compilation_config.is_custom_op_enabled("rotary_embedding")
    if not rope_custom_op:
        return rope_custom_op

    return cfg.compilation_config.use_inductor_graph_partition
|
||||
|
||||
|
||||
def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
    """Predicate used for the ``fuse_act_padding`` setting.

    The result is truthy only when every condition below holds, evaluated in
    order and short-circuiting on the first falsy one:

    1. ``rocm_aiter_ops.is_rmsnorm_enabled()`` is truthy.
    2. ``rocm_aiter_ops.is_triton_gemm_enabled()`` is falsy.
    3. ``cfg.model_config`` is not ``None``.
    4. ``cfg.model_config.get_hidden_size()`` equals 2880 (i.e. gpt-oss);
       otherwise Inductor handles the fusion.

    NOTE(review): the earlier environment-variable form of this check
    required ``VLLM_ROCM_USE_AITER_TRITON_GEMM`` to be set, and the previous
    docstring said "AITER RMSNorm and AITER Triton GEMMs", whereas
    condition 2 requires the Triton GEMM path to be *disabled*. Confirm the
    negation is intended.

    Args:
        cfg: The vLLM configuration whose ``model_config`` is inspected.

    Returns:
        Truthy when the fusion should be enabled, falsy otherwise.
    """
    # Imported at call time rather than at module import time.
    from vllm._aiter_ops import rocm_aiter_ops

    return (
        rocm_aiter_ops.is_rmsnorm_enabled()
        and not rocm_aiter_ops.is_triton_gemm_enabled()
        # model_config may be None, so guard before reading the hidden size.
        and cfg.model_config is not None
        and cfg.model_config.get_hidden_size() == 2880
    )
|
||||
@@ -149,6 +162,7 @@ OPTIMIZATION_LEVEL_00 = {
|
||||
"enable_sp": False,
|
||||
"fuse_gemm_comms": False,
|
||||
"fuse_act_padding": False,
|
||||
"fuse_rope_kvcache": False,
|
||||
},
|
||||
"cudagraph_mode": CUDAGraphMode.NONE,
|
||||
"use_inductor_graph_partition": False,
|
||||
@@ -167,6 +181,7 @@ OPTIMIZATION_LEVEL_01 = {
|
||||
"enable_sp": False,
|
||||
"fuse_gemm_comms": False,
|
||||
"fuse_act_padding": enable_norm_pad_fusion,
|
||||
"fuse_rope_kvcache": enable_rope_kvcache_fusion,
|
||||
},
|
||||
"cudagraph_mode": CUDAGraphMode.PIECEWISE,
|
||||
"use_inductor_graph_partition": False,
|
||||
@@ -185,6 +200,7 @@ OPTIMIZATION_LEVEL_02 = {
|
||||
"enable_sp": IS_DENSE,
|
||||
"fuse_gemm_comms": IS_DENSE,
|
||||
"fuse_act_padding": enable_norm_pad_fusion,
|
||||
"fuse_rope_kvcache": enable_rope_kvcache_fusion,
|
||||
},
|
||||
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
|
||||
"use_inductor_graph_partition": False,
|
||||
@@ -203,6 +219,7 @@ OPTIMIZATION_LEVEL_03 = {
|
||||
"enable_sp": IS_DENSE,
|
||||
"fuse_gemm_comms": IS_DENSE,
|
||||
"fuse_act_padding": enable_norm_pad_fusion,
|
||||
"fuse_rope_kvcache": enable_rope_kvcache_fusion,
|
||||
},
|
||||
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
|
||||
"use_inductor_graph_partition": False,
|
||||
|
||||
Reference in New Issue
Block a user