[ROCm]: Enable customop and rope+kvcache fusion for AITER RoPE (#35180)
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
@@ -177,7 +177,10 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
|
||||
def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
|
||||
ops = []
|
||||
if self.enable_rope_custom_op:
|
||||
ops.append(ROTARY_OP)
|
||||
if rocm_aiter_ops.is_triton_rotary_embed_enabled():
|
||||
ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default)
|
||||
else:
|
||||
ops.append(ROTARY_OP)
|
||||
else:
|
||||
ops.append(INDEX_SELECT_OP)
|
||||
ops.append(torch.ops.vllm.unified_kv_cache_update.default)
|
||||
@@ -196,6 +199,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("enable_rope_custom_op", [True]) # [True, False])
|
||||
@pytest.mark.parametrize("enable_aiter_triton_rope", [True, False])
|
||||
@pytest.mark.parametrize("num_heads", [64])
|
||||
@pytest.mark.parametrize("num_kv_heads", [8])
|
||||
@pytest.mark.parametrize("head_size", [64])
|
||||
@@ -210,6 +214,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
|
||||
def test_rope_kvcache_fusion(
|
||||
attn_backend: AttentionBackendEnum,
|
||||
enable_rope_custom_op: bool,
|
||||
enable_aiter_triton_rope: bool,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
@@ -245,6 +250,9 @@ def test_rope_kvcache_fusion(
|
||||
|
||||
with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
m.setenv(
|
||||
"VLLM_ROCM_USE_AITER_TRITON_ROPE", "1" if enable_aiter_triton_rope else "0"
|
||||
)
|
||||
rocm_aiter_ops.refresh_env_variables()
|
||||
|
||||
model = QKRoPEKVCacheTestModel(
|
||||
|
||||
Reference in New Issue
Block a user