[Neuron][kernel] Fuse kv cache into a single tensor (#15911)

Signed-off-by: Liangfu Chen <liangfc@amazon.com>
This commit is contained in:
Liangfu Chen
2025-04-03 09:51:32 -07:00
committed by GitHub
parent 82e7e19a6e
commit d2b58ca203
3 changed files with 46 additions and 56 deletions

View File

@@ -64,9 +64,11 @@ def test_reshape_and_cache(num_tokens, n_kv_head, d_head, num_blocks,
key_cache = torch.zeros_like(key_cache_cpu, device=device)
value_cache = torch.zeros_like(value_cache_cpu, device=device)
slot_mapping = slot_mapping_cpu.to(device)
kv_cache = torch.stack([key_cache, value_cache])
# Run vectorized implementation on XLA device
reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
reshape_and_cache(key, value, kv_cache, slot_mapping)
key_cache, value_cache = torch.unbind(kv_cache, dim=0)
# Move results back to CPU for comparison
key_cache_result = key_cache.cpu()