Decouple page_size_bytes calculation in AttentionSpec for TPU/RPA Compatibility. (#31635)

Signed-off-by: Lihao Ran <imlihao.ran@gmail.com>
Signed-off-by: Lumosis <30372757+Lumosis@users.noreply.github.com>
Author: Lumosis
Date: 2026-01-08 01:00:07 -08:00
Committed by: GitHub
Parent: eac3b96ec0
Commit: b634e619bb
6 changed files with 75 additions and 20 deletions
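
The hunks below only migrate test call sites from positional to keyword arguments; the decoupling named in the title happens in the spec classes themselves, which this extract does not show. A minimal sketch of the idea, assuming vLLM's conventional formula `page_size_bytes = 2 * block_size * num_kv_heads * head_size * dtype_size`; the class names and the TPU padding rule below are illustrative assumptions, not the patch's actual code:

```python
import math
from dataclasses import dataclass

import torch


def get_dtype_size(dtype: torch.dtype) -> int:
    """Bytes per element of the given dtype."""
    return torch.tensor([], dtype=dtype).element_size()


@dataclass
class AttentionSpecSketch:
    """Illustrative stand-in for vLLM's AttentionSpec."""

    block_size: int
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype

    @property
    def page_size_bytes(self) -> int:
        # Factor of 2: one K plane and one V plane per block.
        return (
            2
            * self.block_size
            * self.num_kv_heads
            * self.head_size
            * get_dtype_size(self.dtype)
        )


@dataclass
class TpuAttentionSpecSketch(AttentionSpecSketch):
    """Hypothetical TPU variant: Ragged Paged Attention pads head_size."""

    head_size_padding: int = 128  # assumed padding multiple, not taken from the patch

    @property
    def page_size_bytes(self) -> int:
        # Round head_size up to the padding multiple before sizing pages.
        padded = math.ceil(self.head_size / self.head_size_padding) * self.head_size_padding
        return 2 * self.block_size * self.num_kv_heads * padded * get_dtype_size(self.dtype)
```

Once `page_size_bytes` sits behind a property like this, a TPU backend can override the calculation without changing the shared KV-cache bookkeeping, and the keyword-argument migration below keeps the tests correct even if the spec's field order shifts during such a refactor.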

@@ -94,7 +94,12 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
         kv_cache_groups=[
             KVCacheGroupSpec(
                 ["layer"],
-                FullAttentionSpec(block_size, 1, 1, torch.float32),
+                FullAttentionSpec(
+                    block_size=block_size,
+                    num_kv_heads=1,
+                    head_size=1,
+                    dtype=torch.float32,
+                ),
             )
         ],
     )
@@ -109,18 +114,31 @@ def make_kv_cache_config_hybrid_model(
         kv_cache_groups=[
             KVCacheGroupSpec(
                 ["layer1"],
-                FullAttentionSpec(block_size, 1, 1, torch.float32),
+                FullAttentionSpec(
+                    block_size=block_size,
+                    num_kv_heads=1,
+                    head_size=1,
+                    dtype=torch.float32,
+                ),
             ),
             KVCacheGroupSpec(
                 ["layer2"],
                 SlidingWindowSpec(
-                    block_size, 1, 1, torch.float32, sliding_window=2 * block_size
+                    block_size=block_size,
+                    num_kv_heads=1,
+                    head_size=1,
+                    dtype=torch.float32,
+                    sliding_window=2 * block_size,
                 ),
             ),
             KVCacheGroupSpec(
                 ["layer3"],
                 SlidingWindowSpec(
-                    block_size, 1, 1, torch.float32, sliding_window=2 * block_size
+                    block_size=block_size,
+                    num_kv_heads=1,
+                    head_size=1,
+                    dtype=torch.float32,
+                    sliding_window=2 * block_size,
                 ),
             ),
         ],
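
One detail worth noting for the hunk above: sliding_window bounds how many blocks a request retains, not how large each page is, so under the conventional formula a SlidingWindowSpec's page size comes from the same four fields as a FullAttentionSpec's. A standalone check, with the helper being an illustrative reimplementation rather than a vLLM API:

```python
import torch


def page_size_bytes(block_size: int, num_kv_heads: int, head_size: int, dtype: torch.dtype) -> int:
    # Conventional formula: K and V planes for one block.
    return 2 * block_size * num_kv_heads * head_size * torch.tensor([], dtype=dtype).element_size()


block_size = 16  # illustrative; the tests parametrize this
# The sliding window changes eviction behavior, not the per-page byte size.
assert page_size_bytes(block_size, 1, 1, torch.float32) == 2 * 16 * 1 * 1 * 4  # 128 bytes
```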
@@ -1616,15 +1634,20 @@ def test_different_block_size():
         kv_cache_groups=[
             KVCacheGroupSpec(
                 ["layer1"],
-                FullAttentionSpec(block_size * 2, 1, 1, torch.float16),
+                FullAttentionSpec(
+                    block_size=block_size * 2,
+                    num_kv_heads=1,
+                    head_size=1,
+                    dtype=torch.float16,
+                ),
             ),
             KVCacheGroupSpec(
                 ["layer2"],
                 SlidingWindowSpec(
-                    block_size,
-                    1,
-                    1,
-                    torch.float32,
+                    block_size=block_size,
+                    num_kv_heads=1,
+                    head_size=1,
+                    dtype=torch.float32,
                     sliding_window=2 * block_size,
                 ),
             ),
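
A worked check of the hunk above, again assuming the conventional formula: the two groups come out with equal page sizes even though their block sizes differ, because the dtype widths compensate (float16 at twice the block size versus float32 at the base block size), which is presumably why the test pairs these dtypes. The helper is the same illustrative reimplementation as before:

```python
import torch


def page_size_bytes(block_size: int, num_kv_heads: int, head_size: int, dtype: torch.dtype) -> int:
    return 2 * block_size * num_kv_heads * head_size * torch.tensor([], dtype=dtype).element_size()


block_size = 16  # illustrative
full = page_size_bytes(block_size * 2, 1, 1, torch.float16)  # 2 * 32 * 1 * 1 * 2 = 128
swa = page_size_bytes(block_size, 1, 1, torch.float32)       # 2 * 16 * 1 * 1 * 4 = 128
assert full == swa  # equal page sizes despite the differing block sizes
```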