[torch.compile] Hide KV cache behind torch.compile boundary (#11677)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-01-10 13:14:42 +08:00
committed by GitHub
parent 3de2b1eafb
commit cf5f000d21
18 changed files with 198 additions and 44 deletions

View File

@@ -142,12 +142,18 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))
# Construct KV cache
kv_cache = make_kv_cache(test_pt.num_blocks,
test_pt.num_heads,
test_pt.head_size,
test_pt.block_size,
device=CUDA_DEVICE,
backend=test_pt.backend_name)
if test_pt.attn_type in (AttentionType.DECODER,
AttentionType.ENCODER_DECODER):
kv_cache = make_kv_cache(test_pt.num_blocks,
test_pt.num_heads,
test_pt.head_size,
test_pt.block_size,
device=CUDA_DEVICE,
backend=test_pt.backend_name)
else:
kv_cache = torch.tensor([])
attn.kv_cache = [kv_cache]
return TestResources(scale, attn, kv_cache)