[BugFix][kv_offload]: Fix kernel block size detection (#35125)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@@ -259,16 +259,20 @@ class CpuGpuOffloadingHandlers:
|
||||
assert gpu_shape[0] == 2
|
||||
split_k_and_v = True
|
||||
|
||||
try:
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
|
||||
include_num_layers_dimension=has_layers_dim
|
||||
)
|
||||
assert len(kv_cache_stride_order) == len(gpu_shape)
|
||||
except (AttributeError, NotImplementedError):
|
||||
kv_cache_stride_order = tuple(range(len(gpu_shape)))
|
||||
if has_layers_dim:
|
||||
# in the cross layers case, the registered kv cache tensor
|
||||
# shape matches the physical layout, whereas test_shape
|
||||
# is the logical layout.
|
||||
# To match them, we need to permute test_shape
|
||||
try:
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
|
||||
include_num_layers_dimension=has_layers_dim
|
||||
)
|
||||
assert len(kv_cache_stride_order) == len(gpu_shape)
|
||||
except (AttributeError, NotImplementedError):
|
||||
kv_cache_stride_order = tuple(range(len(gpu_shape)))
|
||||
|
||||
# permute test_shape according to stride_order
|
||||
test_shape = tuple(test_shape[i] for i in kv_cache_stride_order)
|
||||
test_shape = tuple(test_shape[i] for i in kv_cache_stride_order)
|
||||
|
||||
# find block_size (16) dimension index
|
||||
block_size_idx = test_shape.index(16)
|
||||
|
||||
Reference in New Issue
Block a user