diff --git a/vllm/v1/kv_offload/worker/cpu_gpu.py b/vllm/v1/kv_offload/worker/cpu_gpu.py index a5abae51e..5cde5faa4 100644 --- a/vllm/v1/kv_offload/worker/cpu_gpu.py +++ b/vllm/v1/kv_offload/worker/cpu_gpu.py @@ -259,16 +259,20 @@ class CpuGpuOffloadingHandlers: assert gpu_shape[0] == 2 split_k_and_v = True - try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( - include_num_layers_dimension=has_layers_dim - ) - assert len(kv_cache_stride_order) == len(gpu_shape) - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(gpu_shape))) + if has_layers_dim: + # in the cross layers case, the registered kv cache tensor + # shape matches the physical layout, whereas test_shape + # is the logical layout. + # To match them, we need to permute test_shape + try: + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( + include_num_layers_dimension=has_layers_dim + ) + assert len(kv_cache_stride_order) == len(gpu_shape) + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple(range(len(gpu_shape))) - # permute test_shape according to stride_order - test_shape = tuple(test_shape[i] for i in kv_cache_stride_order) + test_shape = tuple(test_shape[i] for i in kv_cache_stride_order) # find block_size (16) dimension index block_size_idx = test_shape.index(16)