[TPU] kv cache update kernel supports dynamic grid (#20235)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
@@ -32,6 +32,7 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
|
||||
new_kv_xla = new_kv_cpu.to(torch_xla.device())
|
||||
slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9],
|
||||
dtype=np.int32)
|
||||
+ num_kv_update_slices = len(slice_lens)
|
||||
kv_cache_start_indices = np.array([
|
||||
page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6,
|
||||
page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3
|
||||
@@ -52,12 +53,15 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
|
||||
+ num_kv_update_slices_xla = torch.tensor([num_kv_update_slices],
+                                         device=torch_xla.device(),
+                                         dtype=torch.int32)
|
||||
torch_xla.sync()
|
||||
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
|
||||
new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
|
||||
-     new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size,
-     num_slices_per_block)
+     new_kv_xla, slot_mapping_xla, kv_cache_xla, num_kv_update_slices_xla,
+     page_size, num_slices_per_block)
|
||||
kv_cache_xla.copy_(new_kv_cache_xla)
|
||||
torch_xla.sync()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user