[TPU] kv cache update kernel supports dynamic grid (#20235)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
Author: Chengji Yao
Date: 2025-07-01 23:33:37 -07:00
Committed by: GitHub
Parent: b205e8467d
Commit: 7da296be04
4 changed files with 42 additions and 17 deletions
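
This change threads the runtime slice count through to the kernel so the Pallas grid can be sized dynamically instead of being padded to a static maximum. As a rough sketch of the idea (the helper name and the ceil-division are assumptions for illustration, not vLLM's actual kernel code), the number of grid steps follows the live count:

    import math

    def kv_update_grid_steps(num_kv_update_slices: int,
                             num_slices_per_block: int) -> int:
        # Hypothetical helper: each grid step processes num_slices_per_block
        # slices, so the step count tracks the runtime slice count rather
        # than a compile-time padded maximum.
        return math.ceil(num_kv_update_slices / num_slices_per_block)

    # With the test's 7 slices and, say, 8 slices per block, one step suffices:
    assert kv_update_grid_steps(7, 8) == 1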

@@ -32,6 +32,7 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
     new_kv_xla = new_kv_cpu.to(torch_xla.device())
     slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9],
                           dtype=np.int32)
+    num_kv_update_slices = len(slice_lens)
     kv_cache_start_indices = np.array([
         page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6,
         page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3
@@ -52,12 +53,15 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
                                     device="cpu",
                                     dtype=torch.int32)
     slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
+    num_kv_update_slices_xla = torch.tensor([num_kv_update_slices],
+                                            device=torch_xla.device(),
+                                            dtype=torch.int32)
     torch_xla.sync()
     torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
     new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
-        new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size,
-        num_slices_per_block)
+        new_kv_xla, slot_mapping_xla, kv_cache_xla, num_kv_update_slices_xla,
+        page_size, num_slices_per_block)
     kv_cache_xla.copy_(new_kv_cache_xla)
     torch_xla.sync()
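
A plausible reading of the design choice here: passing num_kv_update_slices as a rank-1 int32 tensor on the XLA device, rather than as a Python int, keeps the count out of the traced program's static attributes, so one compiled kernel can be reused across batches with different slice counts; that is what makes the grid dynamic.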