[TPU] kv cache update kernel doesn't need slices padded to a multiple of num_slices_per_block (#22394)

Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
Co-authored-by: Chengji Yao <chengjiyao@gmail.com>
This commit is contained in:
Chengji Yao
2025-08-09 20:49:04 -07:00
committed by GitHub
parent 534c45b962
commit 2a84fb422f
3 changed files with 19 additions and 21 deletions

View File

@@ -43,11 +43,6 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
np.cumsum(slice_lens[:-1])])
slot_mapping = np.stack(
[kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1)
padded_size = (slot_mapping.shape[0] + num_slices_per_block -
1) // num_slices_per_block * num_slices_per_block
slot_mapping = np.pad(slot_mapping,
[[0, padded_size - slot_mapping.shape[0]], [0, 0]],
constant_values=0)
slot_mapping = np.transpose(slot_mapping)
slot_mapping_cpu = torch.tensor(slot_mapping,
device="cpu",