# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
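"""Tests for RopeKVCacheFusionPass: fusing rotary embedding with the unified
KV cache update into a single fused op on ROCm with AITER."""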
import pytest
import torch

import vllm.config
from tests.compile.backend import TestBackend
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops
from vllm.compilation.passes.fusion.matcher_utils import ROTARY_OP
from vllm.compilation.passes.fusion.rope_kvcache_fusion import RopeKVCacheFusionPass
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
from vllm.compilation.passes.utility.scatter_split_replace import (
    ScatterSplitReplacementPass,
)
from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass
from vllm.config import (
    CacheConfig,
    CompilationConfig,
    CompilationMode,
    ModelConfig,
    PassConfig,
    VllmConfig,
)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
from vllm.v1.attention.backend import (
    AttentionBackend,
    CommonAttentionMetadata,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.kv_cache_interface import AttentionSpec

INDEX_SELECT_OP = torch.ops.aten.index.Tensor
VLLM_UNIFIED_KV_CACHE_UPDATE_OP = torch.ops.vllm.unified_kv_cache_update
FP8_DTYPE = current_platform.fp8_dtype()


class QKRoPEKVCacheTestModel(torch.nn.Module):
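    """Minimal module that applies RoPE to Q/K and then calls the unified KV
    cache update op, exposing the pattern targeted by RopeKVCacheFusionPass."""
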
    def __init__(
        self,
        vllm_config: VllmConfig,
        attn_backend: AttentionBackendEnum,
        num_heads: int,
        num_kv_heads: int,
        head_size: int,
        is_neox: bool,
        dtype: torch.dtype,
        device: torch.device,
        prefix: str = "model.layers.0.self_attn.attn",
    ):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size
        self.block_size = vllm_config.cache_config.block_size
        self.q_size = num_heads * head_size
        self.kv_size = num_kv_heads * head_size
        self.is_neox = is_neox
        self.dtype = dtype
        self.device = device
        self.layer_name = prefix

        self.rotary_emb = RotaryEmbedding(
            head_size,
            rotary_dim=head_size,
            max_position_embeddings=4096,
            base=10000,
            is_neox_style=is_neox,
            dtype=self.dtype,
        )

        # Whether to check for the RoPE custom op or component index_select
        self.enable_rope_custom_op = self.rotary_emb.enabled()

        # Register layer metadata for the fusion pass via Attention.
        self.attn = Attention(
            num_heads=num_heads,
            head_size=head_size,
            scale=1.0 / head_size**0.5,
            num_kv_heads=num_kv_heads,
            cache_config=vllm_config.cache_config,
            quant_config=vllm_config.quant_config,
            prefix=prefix,
            attn_backend=attn_backend.get_class(),
        )
        self.attn_backend: type[AttentionBackend] = self.attn.get_attn_backend()
        assert not self.attn_backend.forward_includes_kv_cache_update, (
            f"Attention backend {self.attn_backend} updates the KV cache inside its "
            "forward pass, so fuse_rope_kvcache is not applicable."
        )
        self.attn._k_scale = self.attn._k_scale.to(device)
        self.attn._v_scale = self.attn._v_scale.to(device)

        kv_cache_dtype_str = vllm_config.cache_config.cache_dtype
        self.kv_cache_dtype = (
            FP8_DTYPE if kv_cache_dtype_str.startswith("fp8") else self.dtype
        )

        # Initialize attn MetadataBuilder
        self.builder = self.attn.attn_backend.get_builder_cls()(
            kv_cache_spec=AttentionSpec(
                block_size=self.block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=head_size,
                dtype=self.kv_cache_dtype,
            ),
            layer_names=[self.attn.layer_name],
            vllm_config=vllm_config,
            device=device,
        )

    def build_attn_metadata(self, batch_size: int) -> CommonAttentionMetadata:
        """Initialize attention metadata."""
        # Create common attn metadata
        batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size)
        common_attn_metadata = create_common_attn_metadata(
            batch_spec, self.block_size, self.device, arange_block_indices=True
        )

        max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
        num_blocks = batch_size * max_blocks

        # Fetch the attention backend and kv cache shape and stride order
        attn_backend = self.attn.attn_backend
        kv_cache_shape = attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size
        )
        try:
            kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
        except (AttributeError, NotImplementedError):
            kv_cache_stride_order = tuple(range(len(kv_cache_shape)))

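        # View the cache buffer in its physical (stride-order) layout; inv_order
        # is the inverse permutation that maps it back to the backend's logical
        # KV cache shape.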
        kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
        inv_order = [
            kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
        ]

        # Create dummy KV cache
        raw_tensor = torch.zeros(
            2 * num_blocks * self.block_size * self.num_kv_heads * self.head_size,
            dtype=self.kv_cache_dtype,
            device=self.device,
        )
        raw_tensor = raw_tensor.view(kv_cache_shape)
        kv_cache = raw_tensor.permute(*inv_order)

        self.attn.kv_cache = [kv_cache]

        # Build attn metadata
        attn_metadata = self.builder.build(
            common_prefix_len=0, common_attn_metadata=common_attn_metadata
        )

        return attn_metadata

    def forward(
        self, qkv: torch.Tensor, positions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # Create a copy so in-place ops do not modify the original tensors
        qkv = qkv.clone()
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)

        # Instead of a full forward pass, match only the KV cache update op here
        q = q.view(-1, self.num_heads, self.head_size)
        k = k.view(-1, self.num_kv_heads, self.head_size)
        v = v.view(-1, self.num_kv_heads, self.head_size)
        kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
            k, v, self.layer_name
        )
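        # Return the op's output as well so the cache update keeps a live use
        # in the traced graph.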
        return q, k, v, kv_cache_dummy_dep

    def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
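        """Ops expected in the graph before the fusion pass runs."""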
        ops = []
        if self.enable_rope_custom_op:
            if rocm_aiter_ops.is_triton_rotary_embed_enabled():
                ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default)
            else:
                ops.append(ROTARY_OP)
        else:
            ops.append(INDEX_SELECT_OP)
        ops.append(torch.ops.vllm.unified_kv_cache_update.default)
        return ops

    def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
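        """Ops expected in the graph after the fusion pass has run."""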
        return [torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default]


@pytest.mark.parametrize(
    "attn_backend",
    [
        AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
        AttentionBackendEnum.TRITON_ATTN,
        AttentionBackendEnum.ROCM_ATTN,
        AttentionBackendEnum.ROCM_AITER_FA,
    ],
)
@pytest.mark.parametrize("enable_rope_custom_op", [True]) # [True, False])
|
|
@pytest.mark.parametrize("enable_aiter_triton_rope", [True, False])
|
|
@pytest.mark.parametrize("num_heads", [64])
|
|
@pytest.mark.parametrize("num_kv_heads", [8])
|
|
@pytest.mark.parametrize("head_size", [64])
|
|
@pytest.mark.parametrize("block_size", [16])
|
|
@pytest.mark.parametrize("is_neox", [True, False])
|
|
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
|
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
|
|
@pytest.mark.skipif(
|
|
not is_aiter_found_and_supported(),
|
|
reason="Only test on ROCm with AITER installed and supported",
|
|
)
|
|
def test_rope_kvcache_fusion(
    attn_backend: AttentionBackendEnum,
    enable_rope_custom_op: bool,
    enable_aiter_triton_rope: bool,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    block_size: int,
    is_neox: bool,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    monkeypatch: pytest.MonkeyPatch,
):
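    """Check that RoPE and the unified KV cache update are fused into a single
    op and that the fused graph matches the unfused reference numerically."""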
    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)
    torch.manual_seed(0)

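    # "+rotary_embedding" keeps RoPE as a custom op instead of letting it be
    # decomposed into its component aten ops.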
    custom_ops: list[str] = []
    if enable_rope_custom_op:
        custom_ops.append("+rotary_embedding")

    vllm_config = VllmConfig(
        model_config=ModelConfig(dtype=dtype),
        cache_config=CacheConfig(
            block_size=block_size,
            cache_dtype=kv_cache_dtype,
        ),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=custom_ops,
            pass_config=PassConfig(
                fuse_rope_kvcache=True,
                eliminate_noops=True,
            ),
        ),
    )

    with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
        m.setenv("VLLM_ROCM_USE_AITER", "1")
        m.setenv(
            "VLLM_ROCM_USE_AITER_TRITON_ROPE", "1" if enable_aiter_triton_rope else "0"
        )
        rocm_aiter_ops.refresh_env_variables()

        model = QKRoPEKVCacheTestModel(
            vllm_config=vllm_config,
            attn_backend=attn_backend,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            is_neox=is_neox,
            dtype=dtype,
            device=torch.get_default_device(),
        )

        fusion_pass = RopeKVCacheFusionPass(vllm_config)
        passes = [
            NoOpEliminationPass(vllm_config),
            SplitCoalescingPass(vllm_config),
            ScatterSplitReplacementPass(vllm_config),
            fusion_pass,
            PostCleanupPass(vllm_config),
        ]
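        # TestBackend applies the passes during compilation and records the
        # graphs before/after them for the op checks below.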
        backend = TestBackend(*passes)

        T = 5

        qkv = torch.randn(
            T, num_heads * head_size + 2 * num_kv_heads * head_size, dtype=dtype
        )
        pos = torch.arange(T, dtype=torch.long)

        qkv_unfused = qkv.clone()
        pos_unfused = pos.clone()
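        # Reference run: execute the model eagerly, without the fusion passes.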
        with set_forward_context(None, vllm_config):
            forward_context = get_forward_context()
            attn_metadata = model.build_attn_metadata(T)
            forward_context.slot_mapping = {
                model.layer_name: attn_metadata.slot_mapping
            }
            q_unfused, k_unfused, v_unfused, dummy = model(qkv_unfused, pos_unfused)
            attn_layer = forward_context.no_compile_layers[model.layer_name]
            kv_cache_unfused = attn_layer.kv_cache[forward_context.virtual_engine]
            del dummy

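        # Compiled run: mark the token dimension dynamic and trace through
        # TestBackend so the fusion passes rewrite the graph.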
        torch._dynamo.mark_dynamic(qkv, 0)
        torch._dynamo.mark_dynamic(pos, 0)
        with set_forward_context(None, vllm_config):
            model_fused = torch.compile(model, backend=backend)
            forward_context = get_forward_context()
            attn_metadata = model_fused.build_attn_metadata(T)
            forward_context.slot_mapping = {
                model.layer_name: attn_metadata.slot_mapping
            }
            q_fused, k_fused, v_fused, dummy = model_fused(qkv, pos)
            attn_layer = forward_context.no_compile_layers[model.layer_name]
            kv_cache_fused = attn_layer.kv_cache[forward_context.virtual_engine]
            del dummy

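        # Exactly one RoPE + KV cache update pattern should have been fused.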
        assert fusion_pass.matched_count == 1

        backend.check_before_ops(model.ops_in_model_before())
        backend.check_after_ops(model.ops_in_model_after())

        if dtype == torch.float16:
            ATOL, RTOL = (2e-3, 2e-3)
        else:
            ATOL, RTOL = (1e-2, 1e-2)

        torch.testing.assert_close(q_unfused, q_fused, atol=ATOL, rtol=RTOL)
        torch.testing.assert_close(k_unfused, k_fused, atol=ATOL, rtol=RTOL)
        torch.testing.assert_close(v_unfused, v_fused, atol=ATOL, rtol=RTOL)
        # Cannot compare fp8_* caches directly here; view as the model dtype instead
        torch.testing.assert_close(
            kv_cache_unfused.view(dtype),
            kv_cache_fused.view(dtype),
            atol=ATOL,
            rtol=RTOL,
        )