[Bugfix]: Fix ROCm fusion attn test; use AttentionBackend utils to create kv cache (#33948)
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
@@ -92,6 +92,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
     def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
         """Initialize attention metadata."""
 
+        # TODO (Rohan138) reuse utils from vllm/v1/worker/gpu/attn_utils.py
+
         # Create common attn metadata
         batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size)
         common_attn_metadata = create_common_attn_metadata(
@@ -100,58 +102,31 @@ class AttentionQuantPatternModel(torch.nn.Module):
 
         max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
         num_blocks = batch_size * max_blocks
-        backend = self.attn.backend
-
-        # TODO(luka) use get_kv_cache_stride_order
-        # Create dummy KV cache for the selected backend
-        if backend == AttentionBackendEnum.ROCM_ATTN:
-            # k/v as 1st dimension
-            # HND: [num_blocks, num_kv_heads, block_size, head_size]
-            kv_cache = torch.zeros(
-                2,
-                num_blocks,
-                self.num_kv_heads,
-                self.block_size,
-                self.head_size,
-                dtype=self.kv_cache_dtype,
-                device=self.device,
-            )
-        elif backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
-            # k/v as 1st dimension
-            # NHD: [num_blocks, block_size, num_kv_heads, head_size]
-            kv_cache = torch.zeros(
-                2,
-                num_blocks,
-                self.block_size,
-                self.num_kv_heads,
-                self.head_size,
-                dtype=self.kv_cache_dtype,
-                device=self.device,
-            )
-        elif backend == AttentionBackendEnum.TRITON_ATTN:
-            # k/v as 2nd dimension
-            # NHD: [num_blocks, block_size, num_kv_heads, head_size]
-            kv_cache = torch.zeros(
-                num_blocks,
-                2,
-                self.num_kv_heads,
-                self.block_size,
-                self.head_size,
-                dtype=self.kv_cache_dtype,
-                device=self.device,
-            )
-        elif backend == AttentionBackendEnum.FLASHINFER:
-            kv_cache = torch.zeros(
-                num_blocks,
-                2,
-                self.num_kv_heads,
-                self.block_size,
-                self.head_size,
-                dtype=self.kv_cache_dtype,
-                device=self.device,
-            ).permute(0, 1, 3, 2, 4)
-        else:
-            raise ValueError(f"Unsupported backend: {backend}")
+        # Fetch the attention backend and kv cache shape and stride order
+        attn_backend = self.attn.attn_backend
+        kv_cache_shape = attn_backend.get_kv_cache_shape(
+            num_blocks, self.block_size, self.num_kv_heads, self.head_size
+        )
+        try:
+            kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
+        except (AttributeError, NotImplementedError):
+            kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
+
+        kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
+        inv_order = [
+            kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
+        ]
+
+        # Create dummy KV cache
+        raw_tensor = torch.zeros(
+            2 * num_blocks * self.block_size * self.num_kv_heads * self.head_size,
+            dtype=self.kv_cache_dtype,
+            device=self.device,
+        )
+        raw_tensor = raw_tensor.view(kv_cache_shape)
+        kv_cache = raw_tensor.permute(*inv_order)
+
         self.attn.kv_cache = [kv_cache]
 
         # Build attn metadata
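The added code builds the cache generically: it asks the backend for its logical KV-cache shape, reorders that shape by the backend's stride order to get the physical memory layout, allocates a flat buffer, views it in the physical shape, and permutes it back so the tensor is indexed with the logical shape. The snippet below is a minimal, self-contained sketch of that view-then-permute trick in plain PyTorch; make_kv_cache, the example shape, and the stride order are illustrative placeholders, not vLLM APIs.

import torch


def make_kv_cache(logical_shape, stride_order, dtype=torch.float16, device="cpu"):
    """Allocate a cache indexed by logical_shape but laid out in memory
    according to stride_order (a permutation of the dimension indices)."""
    # Physical shape: logical dims rearranged into the order they are laid
    # out in memory (outermost dimension first).
    physical_shape = tuple(logical_shape[i] for i in stride_order)
    # Inverse permutation: position of each logical dim in the physical layout.
    inv_order = [stride_order.index(i) for i in range(len(stride_order))]

    numel = 1
    for dim in logical_shape:
        numel *= dim
    raw = torch.zeros(numel, dtype=dtype, device=device)
    # View the flat buffer in physical order, then permute so callers index
    # it with the logical shape while the underlying strides stay physical.
    return raw.view(physical_shape).permute(*inv_order)


# Hypothetical example: logical shape [2, num_blocks, block_size, heads, head_size]
# stored with num_blocks as the outermost dimension.
stride_order = (1, 0, 2, 3, 4)
cache = make_kv_cache((2, 4, 16, 8, 64), stride_order)
assert cache.shape == (2, 4, 16, 8, 64)               # logical indexing shape
assert cache.permute(*stride_order).is_contiguous()   # memory follows stride_order

Done this way, a single code path in the test works for any backend that reports its shape and stride order, instead of hard-coding one branch per backend as the removed code did.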