[ROCm]: Enable customop and rope+kvcache fusion for AITER RoPE (#35180)

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
Rohan Potdar
2026-02-24 22:36:40 -06:00
committed by GitHub
parent ec1d30c0f6
commit f38f8c9742
9 changed files with 139 additions and 67 deletions

View File

@@ -177,7 +177,10 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
ops = []
if self.enable_rope_custom_op:
ops.append(ROTARY_OP)
if rocm_aiter_ops.is_triton_rotary_embed_enabled():
ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default)
else:
ops.append(ROTARY_OP)
else:
ops.append(INDEX_SELECT_OP)
ops.append(torch.ops.vllm.unified_kv_cache_update.default)
@@ -196,6 +199,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
],
)
@pytest.mark.parametrize("enable_rope_custom_op", [True]) # [True, False])
@pytest.mark.parametrize("enable_aiter_triton_rope", [True, False])
@pytest.mark.parametrize("num_heads", [64])
@pytest.mark.parametrize("num_kv_heads", [8])
@pytest.mark.parametrize("head_size", [64])
@@ -210,6 +214,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
def test_rope_kvcache_fusion(
attn_backend: AttentionBackendEnum,
enable_rope_custom_op: bool,
enable_aiter_triton_rope: bool,
num_heads: int,
num_kv_heads: int,
head_size: int,
@@ -245,6 +250,9 @@ def test_rope_kvcache_fusion(
with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
m.setenv("VLLM_ROCM_USE_AITER", "1")
m.setenv(
"VLLM_ROCM_USE_AITER_TRITON_ROPE", "1" if enable_aiter_triton_rope else "0"
)
rocm_aiter_ops.refresh_env_variables()
model = QKRoPEKVCacheTestModel(

View File

@@ -831,6 +831,59 @@ def _rocm_aiter_triton_add_rmsnorm_pad_fake(
return out, residual_out
def _triton_rotary_embedding_impl(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    offsets: torch.Tensor | None = None,
) -> None:
    """Apply rotary positional embedding to ``query``/``key`` in-place.

    Dispatches to the AITER Triton kernel. Registered as a custom op with
    ``mutates_args=["query", "key"]``; it returns None and communicates its
    result purely by mutating the two tensors.

    Args:
        positions: Token position indices; flattened to ``(num_tokens,)``.
        query: Query projection; viewed as ``(num_tokens, -1, head_size)``.
        key: Key projection; viewed as ``(num_tokens, -1, head_size)``.
        head_size: Per-head dimension. Also used as the rotary dimension,
            i.e. the full head is rotated (partial-rotary models are not
            handled by this op).
        cos_sin_cache: Concatenated ``[cos | sin]`` table, split on the
            last dimension.
        is_neox: True -> NeoX interleaving (rotate_style 0),
            False -> GPT-J interleaving (rotate_style 1).
        offsets: Optional per-token position offsets forwarded to the kernel.
    """
    # Lazy import: aiter is only available on ROCm installs.
    from aiter.ops.triton.rope.rope import (
        rope_cached_thd_positions_offsets_2c_fwd_inplace,
    )

    num_tokens = positions.numel()
    cos, sin = cos_sin_cache.chunk(2, dim=-1)
    rotate_style = 0 if is_neox else 1
    rotary_dim = head_size
    # These views share storage with the caller's tensors, so the in-place
    # kernel below updates them directly; no reshape-back is needed (the
    # original code's trailing `query = query.view(query_shape)` rebound a
    # local only and was dead code).
    query = query.view(num_tokens, -1, head_size)
    key = key.view(num_tokens, -1, head_size)
    query_ = query[..., :rotary_dim]
    key_ = key[..., :rotary_dim]
    positions = positions.view(*query.shape[:1])
    rope_cached_thd_positions_offsets_2c_fwd_inplace(
        query_,
        key_,
        cos,
        sin,
        positions,
        offsets,
        rotate_style,
        reuse_freqs_front_part=True,
        nope_first=False,
    )
def _triton_rotary_embedding_fake(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
offsets: torch.Tensor | None = None,
) -> None:
return
# Global flag to ensure ops are registered only once
_OPS_REGISTERED = False
@@ -1178,6 +1231,14 @@ class rocm_aiter_ops:
dispatch_key=current_platform.dispatch_key,
)
# Register rocm aiter rotary embedding custom op
direct_register_custom_op(
op_name="rocm_aiter_triton_rotary_embedding",
op_func=_triton_rotary_embedding_impl,
mutates_args=["query", "key"], # These tensors are modified in-place
fake_impl=_triton_rotary_embedding_fake,
)
_OPS_REGISTERED = True
@staticmethod
@@ -1220,6 +1281,10 @@ class rocm_aiter_ops:
def get_triton_add_rmsnorm_pad_op() -> OpOverload:
return torch.ops.vllm.rocm_aiter_triton_add_rmsnorm_pad.default
@staticmethod
def get_triton_rotary_embedding_op() -> OpOverload:
return torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default
@staticmethod
def rms_norm(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
@@ -1482,42 +1547,6 @@ class rocm_aiter_ops:
gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
return y
@staticmethod
def triton_rotary_embed(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    head_size: int,
    rotary_dim: int,
    is_neox_style: bool,
):
    """Apply rotary positional embedding to ``query``/``key`` in-place
    using the AITER Triton kernel (no ``offsets`` support).

    Args:
        positions: Token position indices; flattened to ``(num_tokens,)``.
        query: Query projection; viewed as ``(num_tokens, -1, head_size)``.
        key: Key projection; viewed as ``(num_tokens, -1, head_size)``.
        cos_sin_cache: Concatenated ``[cos | sin]`` table, split on the
            last dimension.
        head_size: Per-head dimension used to reshape query/key.
        rotary_dim: Number of leading elements of each head to rotate
            (may be smaller than ``head_size`` for partial-rotary models).
        is_neox_style: True -> NeoX interleaving (rotate_style 0),
            False -> GPT-J interleaving (rotate_style 1).
    """
    # Lazy import: aiter is only available on ROCm installs.
    from aiter.ops.triton.rope import rope_cached_thd_positions_2c_fwd_inplace

    num_tokens = positions.numel()
    cos, sin = cos_sin_cache.chunk(2, dim=-1)
    query_shape = query.shape
    key_shape = key.shape
    rotate_style = 0 if is_neox_style else 1
    # Views share storage with the caller's tensors, so the in-place
    # kernel mutates them directly.
    query = query.view(num_tokens, -1, head_size)
    key = key.view(num_tokens, -1, head_size)
    # Rotate only the first `rotary_dim` elements of each head.
    query_ = query[..., :rotary_dim]
    key_ = key[..., :rotary_dim]
    positions = positions.view(*query.shape[:1])
    rope_cached_thd_positions_2c_fwd_inplace(
        query_,
        key_,
        cos,
        sin,
        positions,
        rotate_style,
        reuse_freqs_front_part=True,
        nope_first=False,
    )
    # NOTE(review): these rebind locals only; the caller's tensors were
    # already updated in-place through the views above.
    query = query.view(query_shape)
    key = key.view(key_shape)
@staticmethod
def triton_rope_and_cache(
query: torch.Tensor,

View File

@@ -89,10 +89,13 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
num_heads: int,
num_kv_heads: int,
use_flashinfer: bool = False,
match_rocm_aiter: bool | None = None,
enabled: bool | None = None,
) -> None:
if enabled is None:
enabled = RotaryEmbedding.enabled()
if match_rocm_aiter is None:
match_rocm_aiter = rocm_aiter_ops.is_triton_rotary_embed_enabled()
super().__init__(enabled)
self.is_neox = is_neox
@@ -104,6 +107,8 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
self.rotary_dim = head_size
if use_flashinfer:
self.rotary_op = FLASHINFER_ROTARY_OP
elif match_rocm_aiter:
self.rotary_op = rocm_aiter_ops.get_triton_rotary_embedding_op()
else:
self.rotary_op = ROTARY_OP

View File

@@ -60,6 +60,10 @@ class ScatterSplitReplacementPass(VllmInductorPass):
def __call__(self, graph: fx.Graph) -> None:
count = 0
target_ops = [torch.ops._C.rotary_embedding.default]
if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
target_ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default)
for node in graph.nodes:
if not is_func(node, auto_functionalized):
continue
@@ -67,7 +71,7 @@ class ScatterSplitReplacementPass(VllmInductorPass):
kwargs = node.kwargs
at_target = node.args[0]
if at_target == torch.ops._C.rotary_embedding.default:
if at_target in target_ops:
query = kwargs["query"]
key = kwargs["key"]
getitem_nodes = {}

View File

@@ -123,6 +123,8 @@ class PassConfig:
"""Enable async TP."""
fuse_allreduce_rms: bool = Field(default=None)
"""Enable flashinfer allreduce fusion."""
enable_qk_norm_rope_fusion: bool = False
"""Enable fused Q/K RMSNorm + RoPE pass."""
# ROCm/AITER specific fusions
fuse_act_padding: bool = Field(default=None)
@@ -153,8 +155,6 @@ class PassConfig:
8: 1, # 1MB
},
}, where key is the device capability"""
enable_qk_norm_rope_fusion: bool = False
"""Enable fused Q/K RMSNorm + RoPE pass."""
# TODO(luka) better pass enabling system.
@@ -834,23 +834,20 @@ class CompilationConfig:
func if isinstance(func, InductorPass) else CallableInductorPass(func)
)
if self.pass_config.enable_qk_norm_rope_fusion:
if (
self.pass_config.enable_qk_norm_rope_fusion
and "+rotary_embedding" not in self.custom_ops
):
# TODO(zhuhaoran): support rope native forward match and remove this.
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
self.custom_ops.append("+rotary_embedding")
if self.pass_config.fuse_rope_kvcache:
from vllm._aiter_ops import rocm_aiter_ops
if rocm_aiter_ops.is_triton_rotary_embed_enabled():
logger.warning(
"Cannot use VLLM_ROCM_USE_AITER_TRITON_ROPE with "
"fuse_rope_kvcache. Disabling fuse_rope_kvcache."
)
self.pass_config.fuse_rope_kvcache = False
else:
# TODO(Rohan138): support rope native forward match and remove this.
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
self.custom_ops.append("+rotary_embedding")
if (
self.pass_config.fuse_rope_kvcache
and "+rotary_embedding" not in self.custom_ops
):
# TODO(Rohan138): support rope native forward match and remove this.
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
self.custom_ops.append("+rotary_embedding")
if (
is_torch_equal_or_newer("2.9.0.dev")

View File

@@ -126,14 +126,27 @@ def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool:
)
def enable_rope_kvcache_fusion(cfg: "VllmConfig") -> bool:
"""Enable if rotary embedding custom op is active and
use_inductor_graph_partition is enabled.
"""
from vllm._aiter_ops import rocm_aiter_ops
return (
rocm_aiter_ops.is_enabled()
and cfg.compilation_config.is_custom_op_enabled("rotary_embedding")
and cfg.compilation_config.use_inductor_graph_partition
)
def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
"""Enable if using AITER RMSNorm and AITER Triton GEMMs
and hidden size is 2880 i.e. gpt-oss; otherwise Inductor handles fusion."""
from vllm._aiter_ops import rocm_aiter_ops
return (
envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_RMSNORM
and envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
rocm_aiter_ops.is_rmsnorm_enabled()
and not rocm_aiter_ops.is_triton_gemm_enabled()
and cfg.model_config is not None
and cfg.model_config.get_hidden_size() == 2880
)
@@ -149,6 +162,7 @@ OPTIMIZATION_LEVEL_00 = {
"enable_sp": False,
"fuse_gemm_comms": False,
"fuse_act_padding": False,
"fuse_rope_kvcache": False,
},
"cudagraph_mode": CUDAGraphMode.NONE,
"use_inductor_graph_partition": False,
@@ -167,6 +181,7 @@ OPTIMIZATION_LEVEL_01 = {
"enable_sp": False,
"fuse_gemm_comms": False,
"fuse_act_padding": enable_norm_pad_fusion,
"fuse_rope_kvcache": enable_rope_kvcache_fusion,
},
"cudagraph_mode": CUDAGraphMode.PIECEWISE,
"use_inductor_graph_partition": False,
@@ -185,6 +200,7 @@ OPTIMIZATION_LEVEL_02 = {
"enable_sp": IS_DENSE,
"fuse_gemm_comms": IS_DENSE,
"fuse_act_padding": enable_norm_pad_fusion,
"fuse_rope_kvcache": enable_rope_kvcache_fusion,
},
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
"use_inductor_graph_partition": False,
@@ -203,6 +219,7 @@ OPTIMIZATION_LEVEL_03 = {
"enable_sp": IS_DENSE,
"fuse_gemm_comms": IS_DENSE,
"fuse_act_padding": enable_norm_pad_fusion,
"fuse_rope_kvcache": enable_rope_kvcache_fusion,
},
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
"use_inductor_graph_partition": False,

View File

@@ -105,7 +105,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = True
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False
VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = True
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_AITER_FP4BMM: bool = True
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
@@ -937,9 +937,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
os.getenv("VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", "False").lower() in ("true", "1")
),
# Whether to use aiter rope.
# By default is disabled.
# By default is enabled.
"VLLM_ROCM_USE_AITER_TRITON_ROPE": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "False").lower() in ("true", "1")
os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "True").lower() in ("true", "1")
),
# Whether to use aiter triton fp8 bmm kernel
# By default is enabled.

View File

@@ -47,15 +47,20 @@ class RotaryEmbeddingBase(CustomOp):
if not hasattr(self, "use_flashinfer"):
self.use_flashinfer = False
self.use_aiter = (
self.enabled() and rocm_aiter_ops.is_triton_rotary_embed_enabled()
)
if self.use_aiter:
self.rocm_aiter_triton_rotary_embedding = (
rocm_aiter_ops.get_triton_rotary_embedding_op()
)
if init_cache:
cache = self._compute_cos_sin_cache()
if not self.use_flashinfer:
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
self.is_rocm_triton_rotary_embed_enabled = (
rocm_aiter_ops.is_triton_rotary_embed_enabled()
)
self.apply_rotary_emb = ApplyRotaryEmb(
is_neox_style=self.is_neox_style,
@@ -231,15 +236,14 @@ class RotaryEmbedding(RotaryEmbeddingBase):
query: torch.Tensor,
key: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
if self.is_rocm_triton_rotary_embed_enabled:
if self.use_aiter:
cos_sin_cache = self._match_cos_sin_cache_dtype(query)
rocm_aiter_ops.triton_rotary_embed(
self.rocm_aiter_triton_rotary_embedding(
positions,
query,
key,
cos_sin_cache,
self.head_size,
self.rotary_dim,
cos_sin_cache,
self.is_neox_style,
)
return query, key

View File

@@ -494,6 +494,7 @@ class RocmPlatform(Platform):
use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enabled()
use_aiter_fused_se = rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
use_aiter_triton_rope = rocm_aiter_ops.is_triton_rotary_embed_enabled()
if compilation_config.cudagraph_mode.has_full_cudagraphs():
# decode context parallel does not support full cudagraphs
@@ -558,6 +559,13 @@ class RocmPlatform(Platform):
and "-grouped_topk" not in compilation_config.custom_ops
):
compilation_config.custom_ops.append("+grouped_topk")
# Enable rotary embedding when using AITER if it's not disabled by the user
if (
use_aiter_triton_rope
and "+rotary_embedding" not in compilation_config.custom_ops
and "-rotary_embedding" not in compilation_config.custom_ops
):
compilation_config.custom_ops.append("+rotary_embedding")
# Default dispatch to rocm's sparse_attn_indexer implementation
compilation_config.custom_ops.append("+sparse_attn_indexer")