[V1] Allocate kv_cache with stride order for V1 (#18775)
Signed-off-by: nicklucche <nlucches@redhat.com>
This commit is contained in:
@@ -2033,9 +2033,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_blocks, kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||
dtype = kv_cache_spec.dtype
|
||||
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
try:
|
||||
kv_cache_stride_order = self.attn_backends[
|
||||
i].get_kv_cache_stride_order()
|
||||
assert len(kv_cache_stride_order) == len(
|
||||
kv_cache_shape)
|
||||
except (AttributeError, NotImplementedError):
|
||||
kv_cache_stride_order = tuple(
|
||||
range(len(kv_cache_shape)))
|
||||
# The allocation respects the backend-defined stride order
|
||||
# to ensure the semantic remains consistent for each
|
||||
# backend. We first obtain the generic kv cache shape and
|
||||
# then permute it according to the stride order which could
|
||||
# result in a non-contiguous tensor.
|
||||
kv_cache_shape = tuple(kv_cache_shape[i]
|
||||
for i in kv_cache_stride_order)
|
||||
# Maintain original KV shape view.
|
||||
inv_order = [
|
||||
kv_cache_stride_order.index(i)
|
||||
for i in range(len(kv_cache_stride_order))
|
||||
]
|
||||
kv_caches[layer_name] = torch.zeros(
|
||||
kv_cache_shape, dtype=dtype,
|
||||
device=self.device).permute(*inv_order)
|
||||
else:
|
||||
# TODO: add new branches when introducing more types of
|
||||
# KV cache specs.
|
||||
|
||||
Reference in New Issue
Block a user