[TPU] kv cache update kernel doesn't need slices to be padded to a multiple of num_slices_per_block (#22394)
Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
Co-authored-by: Chengji Yao <chengjiyao@gmail.com>
@@ -43,11 +43,6 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
         np.cumsum(slice_lens[:-1])])
     slot_mapping = np.stack(
         [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1)
-    padded_size = (slot_mapping.shape[0] + num_slices_per_block -
-                   1) // num_slices_per_block * num_slices_per_block
-    slot_mapping = np.pad(slot_mapping,
-                          [[0, padded_size - slot_mapping.shape[0]], [0, 0]],
-                          constant_values=0)
     slot_mapping = np.transpose(slot_mapping)
     slot_mapping_cpu = torch.tensor(slot_mapping,
                                     device="cpu",
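For context, a minimal sketch of the round-up padding that the removed lines performed (the shapes and values below are illustrative, not taken from the test): (n + k - 1) // k * k rounds n up to the nearest multiple of k, and np.pad appends zero rows until slot_mapping reaches that size.

import numpy as np

# Illustrative values, not from the test itself.
num_slices_per_block = 8
slot_mapping = np.arange(15, dtype=np.int32).reshape(5, 3)  # 5 slices, 3 fields each

# (n + k - 1) // k * k rounds n up to the nearest multiple of k: 5 -> 8.
padded_size = (slot_mapping.shape[0] + num_slices_per_block -
               1) // num_slices_per_block * num_slices_per_block
padded = np.pad(slot_mapping,
                [[0, padded_size - slot_mapping.shape[0]], [0, 0]],
                constant_values=0)
assert padded.shape == (8, 3)  # three all-zero rows were appended

Per the commit title, the kv cache update kernel no longer requires this: the unpadded (5, 3) slot_mapping can be passed directly.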