[Neuron][kernel] Fuse kv cache into a single tensor (#15911)
Signed-off-by: Liangfu Chen <liangfc@amazon.com>
This commit is contained in:
@@ -64,9 +64,11 @@ def test_reshape_and_cache(num_tokens, n_kv_head, d_head, num_blocks,
|
||||
key_cache = torch.zeros_like(key_cache_cpu, device=device)
|
||||
value_cache = torch.zeros_like(value_cache_cpu, device=device)
|
||||
slot_mapping = slot_mapping_cpu.to(device)
|
||||
kv_cache = torch.stack([key_cache, value_cache])
|
||||
|
||||
# Run vectorized implementation on XLA device
|
||||
reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
|
||||
reshape_and_cache(key, value, kv_cache, slot_mapping)
|
||||
key_cache, value_cache = torch.unbind(kv_cache, dim=0)
|
||||
|
||||
# Move results back to CPU for comparison
|
||||
key_cache_result = key_cache.cpu()
|
||||
|
||||
Reference in New Issue
Block a user