[ROCm]: Enable customop and rope+kvcache fusion for AITER RoPE (#35180)
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
@@ -177,7 +177,10 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
|
||||
def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
|
||||
ops = []
|
||||
if self.enable_rope_custom_op:
|
||||
ops.append(ROTARY_OP)
|
||||
if rocm_aiter_ops.is_triton_rotary_embed_enabled():
|
||||
ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default)
|
||||
else:
|
||||
ops.append(ROTARY_OP)
|
||||
else:
|
||||
ops.append(INDEX_SELECT_OP)
|
||||
ops.append(torch.ops.vllm.unified_kv_cache_update.default)
|
||||
@@ -196,6 +199,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("enable_rope_custom_op", [True]) # [True, False])
|
||||
@pytest.mark.parametrize("enable_aiter_triton_rope", [True, False])
|
||||
@pytest.mark.parametrize("num_heads", [64])
|
||||
@pytest.mark.parametrize("num_kv_heads", [8])
|
||||
@pytest.mark.parametrize("head_size", [64])
|
||||
@@ -210,6 +214,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
|
||||
def test_rope_kvcache_fusion(
|
||||
attn_backend: AttentionBackendEnum,
|
||||
enable_rope_custom_op: bool,
|
||||
enable_aiter_triton_rope: bool,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
@@ -245,6 +250,9 @@ def test_rope_kvcache_fusion(
|
||||
|
||||
with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
m.setenv(
|
||||
"VLLM_ROCM_USE_AITER_TRITON_ROPE", "1" if enable_aiter_triton_rope else "0"
|
||||
)
|
||||
rocm_aiter_ops.refresh_env_variables()
|
||||
|
||||
model = QKRoPEKVCacheTestModel(
|
||||
|
||||
@@ -831,6 +831,59 @@ def _rocm_aiter_triton_add_rmsnorm_pad_fake(
|
||||
return out, residual_out
|
||||
|
||||
|
||||
def _triton_rotary_embedding_impl(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
head_size: int,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool,
|
||||
offsets: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
# Modifies query and key in-place
|
||||
from aiter.ops.triton.rope.rope import (
|
||||
rope_cached_thd_positions_offsets_2c_fwd_inplace,
|
||||
)
|
||||
|
||||
num_tokens = positions.numel()
|
||||
cos, sin = cos_sin_cache.chunk(2, dim=-1)
|
||||
query_shape = query.shape
|
||||
key_shape = key.shape
|
||||
rotate_style = 0 if is_neox else 1
|
||||
rotary_dim = head_size
|
||||
|
||||
query = query.view(num_tokens, -1, head_size)
|
||||
key = key.view(num_tokens, -1, head_size)
|
||||
query_ = query[..., :rotary_dim]
|
||||
key_ = key[..., :rotary_dim]
|
||||
positions = positions.view(*query.shape[:1])
|
||||
rope_cached_thd_positions_offsets_2c_fwd_inplace(
|
||||
query_,
|
||||
key_,
|
||||
cos,
|
||||
sin,
|
||||
positions,
|
||||
offsets,
|
||||
rotate_style,
|
||||
reuse_freqs_front_part=True,
|
||||
nope_first=False,
|
||||
)
|
||||
query = query.view(query_shape)
|
||||
key = key.view(key_shape)
|
||||
|
||||
|
||||
def _triton_rotary_embedding_fake(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
head_size: int,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox_style: bool,
|
||||
offsets: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
# Global flag to ensure ops are registered only once
|
||||
_OPS_REGISTERED = False
|
||||
|
||||
@@ -1178,6 +1231,14 @@ class rocm_aiter_ops:
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
# Register rocm aiter rotary embedding custom op
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_triton_rotary_embedding",
|
||||
op_func=_triton_rotary_embedding_impl,
|
||||
mutates_args=["query", "key"], # These tensors are modified in-place
|
||||
fake_impl=_triton_rotary_embedding_fake,
|
||||
)
|
||||
|
||||
_OPS_REGISTERED = True
|
||||
|
||||
@staticmethod
|
||||
@@ -1220,6 +1281,10 @@ class rocm_aiter_ops:
|
||||
def get_triton_add_rmsnorm_pad_op() -> OpOverload:
|
||||
return torch.ops.vllm.rocm_aiter_triton_add_rmsnorm_pad.default
|
||||
|
||||
@staticmethod
|
||||
def get_triton_rotary_embedding_op() -> OpOverload:
|
||||
return torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default
|
||||
|
||||
@staticmethod
|
||||
def rms_norm(
|
||||
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
|
||||
@@ -1482,42 +1547,6 @@ class rocm_aiter_ops:
|
||||
gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def triton_rotary_embed(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
is_neox_style: bool,
|
||||
):
|
||||
from aiter.ops.triton.rope import rope_cached_thd_positions_2c_fwd_inplace
|
||||
|
||||
num_tokens = positions.numel()
|
||||
cos, sin = cos_sin_cache.chunk(2, dim=-1)
|
||||
query_shape = query.shape
|
||||
key_shape = key.shape
|
||||
rotate_style = 0 if is_neox_style else 1
|
||||
|
||||
query = query.view(num_tokens, -1, head_size)
|
||||
key = key.view(num_tokens, -1, head_size)
|
||||
query_ = query[..., :rotary_dim]
|
||||
key_ = key[..., :rotary_dim]
|
||||
positions = positions.view(*query.shape[:1])
|
||||
rope_cached_thd_positions_2c_fwd_inplace(
|
||||
query_,
|
||||
key_,
|
||||
cos,
|
||||
sin,
|
||||
positions,
|
||||
rotate_style,
|
||||
reuse_freqs_front_part=True,
|
||||
nope_first=False,
|
||||
)
|
||||
query = query.view(query_shape)
|
||||
key = key.view(key_shape)
|
||||
|
||||
@staticmethod
|
||||
def triton_rope_and_cache(
|
||||
query: torch.Tensor,
|
||||
|
||||
@@ -89,10 +89,13 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
use_flashinfer: bool = False,
|
||||
match_rocm_aiter: bool | None = None,
|
||||
enabled: bool | None = None,
|
||||
) -> None:
|
||||
if enabled is None:
|
||||
enabled = RotaryEmbedding.enabled()
|
||||
if match_rocm_aiter is None:
|
||||
match_rocm_aiter = rocm_aiter_ops.is_triton_rotary_embed_enabled()
|
||||
|
||||
super().__init__(enabled)
|
||||
self.is_neox = is_neox
|
||||
@@ -104,6 +107,8 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
|
||||
self.rotary_dim = head_size
|
||||
if use_flashinfer:
|
||||
self.rotary_op = FLASHINFER_ROTARY_OP
|
||||
elif match_rocm_aiter:
|
||||
self.rotary_op = rocm_aiter_ops.get_triton_rotary_embedding_op()
|
||||
else:
|
||||
self.rotary_op = ROTARY_OP
|
||||
|
||||
|
||||
@@ -60,6 +60,10 @@ class ScatterSplitReplacementPass(VllmInductorPass):
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
count = 0
|
||||
|
||||
target_ops = [torch.ops._C.rotary_embedding.default]
|
||||
if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
|
||||
target_ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default)
|
||||
|
||||
for node in graph.nodes:
|
||||
if not is_func(node, auto_functionalized):
|
||||
continue
|
||||
@@ -67,7 +71,7 @@ class ScatterSplitReplacementPass(VllmInductorPass):
|
||||
kwargs = node.kwargs
|
||||
at_target = node.args[0]
|
||||
|
||||
if at_target == torch.ops._C.rotary_embedding.default:
|
||||
if at_target in target_ops:
|
||||
query = kwargs["query"]
|
||||
key = kwargs["key"]
|
||||
getitem_nodes = {}
|
||||
|
||||
@@ -123,6 +123,8 @@ class PassConfig:
|
||||
"""Enable async TP."""
|
||||
fuse_allreduce_rms: bool = Field(default=None)
|
||||
"""Enable flashinfer allreduce fusion."""
|
||||
enable_qk_norm_rope_fusion: bool = False
|
||||
"""Enable fused Q/K RMSNorm + RoPE pass."""
|
||||
|
||||
# ROCm/AITER specific fusions
|
||||
fuse_act_padding: bool = Field(default=None)
|
||||
@@ -153,8 +155,6 @@ class PassConfig:
|
||||
8: 1, # 1MB
|
||||
},
|
||||
}, where key is the device capability"""
|
||||
enable_qk_norm_rope_fusion: bool = False
|
||||
"""Enable fused Q/K RMSNorm + RoPE pass."""
|
||||
|
||||
# TODO(luka) better pass enabling system.
|
||||
|
||||
@@ -834,23 +834,20 @@ class CompilationConfig:
|
||||
func if isinstance(func, InductorPass) else CallableInductorPass(func)
|
||||
)
|
||||
|
||||
if self.pass_config.enable_qk_norm_rope_fusion:
|
||||
if (
|
||||
self.pass_config.enable_qk_norm_rope_fusion
|
||||
and "+rotary_embedding" not in self.custom_ops
|
||||
):
|
||||
# TODO(zhuhaoran): support rope native forward match and remove this.
|
||||
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
|
||||
self.custom_ops.append("+rotary_embedding")
|
||||
if self.pass_config.fuse_rope_kvcache:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
if rocm_aiter_ops.is_triton_rotary_embed_enabled():
|
||||
logger.warning(
|
||||
"Cannot use VLLM_ROCM_USE_AITER_TRITON_ROPE with "
|
||||
"fuse_rope_kvcache. Disabling fuse_rope_kvcache."
|
||||
)
|
||||
self.pass_config.fuse_rope_kvcache = False
|
||||
else:
|
||||
# TODO(Rohan138): support rope native forward match and remove this.
|
||||
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
|
||||
self.custom_ops.append("+rotary_embedding")
|
||||
if (
|
||||
self.pass_config.fuse_rope_kvcache
|
||||
and "+rotary_embedding" not in self.custom_ops
|
||||
):
|
||||
# TODO(Rohan138): support rope native forward match and remove this.
|
||||
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
|
||||
self.custom_ops.append("+rotary_embedding")
|
||||
|
||||
if (
|
||||
is_torch_equal_or_newer("2.9.0.dev")
|
||||
|
||||
@@ -126,14 +126,27 @@ def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool:
|
||||
)
|
||||
|
||||
|
||||
def enable_rope_kvcache_fusion(cfg: "VllmConfig") -> bool:
|
||||
"""Enable if rotary embedding custom op is active and
|
||||
use_inductor_graph_partition is enabled.
|
||||
"""
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
return (
|
||||
rocm_aiter_ops.is_enabled()
|
||||
and cfg.compilation_config.is_custom_op_enabled("rotary_embedding")
|
||||
and cfg.compilation_config.use_inductor_graph_partition
|
||||
)
|
||||
|
||||
|
||||
def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
|
||||
"""Enable if using AITER RMSNorm and AITER Triton GEMMs
|
||||
and hidden size is 2880 i.e. gpt-oss; otherwise Inductor handles fusion."""
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
return (
|
||||
envs.VLLM_ROCM_USE_AITER
|
||||
and envs.VLLM_ROCM_USE_AITER_RMSNORM
|
||||
and envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
|
||||
rocm_aiter_ops.is_rmsnorm_enabled()
|
||||
and not rocm_aiter_ops.is_triton_gemm_enabled()
|
||||
and cfg.model_config is not None
|
||||
and cfg.model_config.get_hidden_size() == 2880
|
||||
)
|
||||
@@ -149,6 +162,7 @@ OPTIMIZATION_LEVEL_00 = {
|
||||
"enable_sp": False,
|
||||
"fuse_gemm_comms": False,
|
||||
"fuse_act_padding": False,
|
||||
"fuse_rope_kvcache": False,
|
||||
},
|
||||
"cudagraph_mode": CUDAGraphMode.NONE,
|
||||
"use_inductor_graph_partition": False,
|
||||
@@ -167,6 +181,7 @@ OPTIMIZATION_LEVEL_01 = {
|
||||
"enable_sp": False,
|
||||
"fuse_gemm_comms": False,
|
||||
"fuse_act_padding": enable_norm_pad_fusion,
|
||||
"fuse_rope_kvcache": enable_rope_kvcache_fusion,
|
||||
},
|
||||
"cudagraph_mode": CUDAGraphMode.PIECEWISE,
|
||||
"use_inductor_graph_partition": False,
|
||||
@@ -185,6 +200,7 @@ OPTIMIZATION_LEVEL_02 = {
|
||||
"enable_sp": IS_DENSE,
|
||||
"fuse_gemm_comms": IS_DENSE,
|
||||
"fuse_act_padding": enable_norm_pad_fusion,
|
||||
"fuse_rope_kvcache": enable_rope_kvcache_fusion,
|
||||
},
|
||||
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
|
||||
"use_inductor_graph_partition": False,
|
||||
@@ -203,6 +219,7 @@ OPTIMIZATION_LEVEL_03 = {
|
||||
"enable_sp": IS_DENSE,
|
||||
"fuse_gemm_comms": IS_DENSE,
|
||||
"fuse_act_padding": enable_norm_pad_fusion,
|
||||
"fuse_rope_kvcache": enable_rope_kvcache_fusion,
|
||||
},
|
||||
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
|
||||
"use_inductor_graph_partition": False,
|
||||
|
||||
@@ -105,7 +105,7 @@ if TYPE_CHECKING:
|
||||
VLLM_ROCM_USE_AITER_MLA: bool = True
|
||||
VLLM_ROCM_USE_AITER_MHA: bool = True
|
||||
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
|
||||
VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False
|
||||
VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = True
|
||||
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
|
||||
VLLM_ROCM_USE_AITER_FP4BMM: bool = True
|
||||
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
|
||||
@@ -937,9 +937,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
os.getenv("VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", "False").lower() in ("true", "1")
|
||||
),
|
||||
# Whether to use aiter rope.
|
||||
# By default is disabled.
|
||||
# By default is enabled.
|
||||
"VLLM_ROCM_USE_AITER_TRITON_ROPE": lambda: (
|
||||
os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "False").lower() in ("true", "1")
|
||||
os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "True").lower() in ("true", "1")
|
||||
),
|
||||
# Whether to use aiter triton fp8 bmm kernel
|
||||
# By default is enabled.
|
||||
|
||||
@@ -47,15 +47,20 @@ class RotaryEmbeddingBase(CustomOp):
|
||||
if not hasattr(self, "use_flashinfer"):
|
||||
self.use_flashinfer = False
|
||||
|
||||
self.use_aiter = (
|
||||
self.enabled() and rocm_aiter_ops.is_triton_rotary_embed_enabled()
|
||||
)
|
||||
if self.use_aiter:
|
||||
self.rocm_aiter_triton_rotary_embedding = (
|
||||
rocm_aiter_ops.get_triton_rotary_embedding_op()
|
||||
)
|
||||
|
||||
if init_cache:
|
||||
cache = self._compute_cos_sin_cache()
|
||||
if not self.use_flashinfer:
|
||||
cache = cache.to(dtype)
|
||||
self.cos_sin_cache: torch.Tensor
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
self.is_rocm_triton_rotary_embed_enabled = (
|
||||
rocm_aiter_ops.is_triton_rotary_embed_enabled()
|
||||
)
|
||||
|
||||
self.apply_rotary_emb = ApplyRotaryEmb(
|
||||
is_neox_style=self.is_neox_style,
|
||||
@@ -231,15 +236,14 @@ class RotaryEmbedding(RotaryEmbeddingBase):
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
if self.is_rocm_triton_rotary_embed_enabled:
|
||||
if self.use_aiter:
|
||||
cos_sin_cache = self._match_cos_sin_cache_dtype(query)
|
||||
rocm_aiter_ops.triton_rotary_embed(
|
||||
self.rocm_aiter_triton_rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
cos_sin_cache,
|
||||
self.head_size,
|
||||
self.rotary_dim,
|
||||
cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
)
|
||||
return query, key
|
||||
|
||||
@@ -494,6 +494,7 @@ class RocmPlatform(Platform):
|
||||
use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
|
||||
use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enabled()
|
||||
use_aiter_fused_se = rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
|
||||
use_aiter_triton_rope = rocm_aiter_ops.is_triton_rotary_embed_enabled()
|
||||
|
||||
if compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
# decode context parallel does not support full cudagraphs
|
||||
@@ -558,6 +559,13 @@ class RocmPlatform(Platform):
|
||||
and "-grouped_topk" not in compilation_config.custom_ops
|
||||
):
|
||||
compilation_config.custom_ops.append("+grouped_topk")
|
||||
# Enable rotary embedding when using AITER if its not disabled by user
|
||||
if (
|
||||
use_aiter_triton_rope
|
||||
and "+rotary_embedding" not in compilation_config.custom_ops
|
||||
and "-rotary_embedding" not in compilation_config.custom_ops
|
||||
):
|
||||
compilation_config.custom_ops.append("+rotary_embedding")
|
||||
|
||||
# Default dispatch to rocm's sparse_attn_indexer implementation
|
||||
compilation_config.custom_ops.append("+sparse_attn_indexer")
|
||||
|
||||
Reference in New Issue
Block a user