[V1] Allocate kv_cache with stride order for V1 (#18775)

Signed-off-by: nicklucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi
2025-05-29 19:54:16 +02:00
committed by GitHub
parent d58f9c7f7a
commit 32ce3cf7c9
2 changed files with 81 additions and 16 deletions

View File

@@ -2033,9 +2033,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
try:
kv_cache_stride_order = self.attn_backends[
i].get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(
kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(
range(len(kv_cache_shape)))
# The allocation respects the backend-defined stride order
# to ensure the semantic remains consistent for each
# backend. We first obtain the generic kv cache shape and
# then permute it according to the stride order which could
# result in a non-contiguous tensor.
kv_cache_shape = tuple(kv_cache_shape[i]
for i in kv_cache_stride_order)
# Maintain original KV shape view.
inv_order = [
kv_cache_stride_order.index(i)
for i in range(len(kv_cache_stride_order))
]
kv_caches[layer_name] = torch.zeros(
kv_cache_shape, dtype=dtype,
device=self.device).permute(*inv_order)
else:
# TODO: add new branches when introducing more types of
# KV cache specs.