Decouple page_size_bytes calculation in AttentionSpec for TPU/RPA Compatibility. (#31635)
Signed-off-by: Lihao Ran <imlihao.ran@gmail.com> Signed-off-by: Lumosis <30372757+Lumosis@users.noreply.github.com>
This commit is contained in:
@@ -94,7 +94,12 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer"],
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32),
|
||||
FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -109,18 +114,31 @@ def make_kv_cache_config_hybrid_model(
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer1"],
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32),
|
||||
FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
["layer2"],
|
||||
SlidingWindowSpec(
|
||||
block_size, 1, 1, torch.float32, sliding_window=2 * block_size
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=2 * block_size,
|
||||
),
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
["layer3"],
|
||||
SlidingWindowSpec(
|
||||
block_size, 1, 1, torch.float32, sliding_window=2 * block_size
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=2 * block_size,
|
||||
),
|
||||
),
|
||||
],
|
||||
@@ -1616,15 +1634,20 @@ def test_different_block_size():
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer1"],
|
||||
FullAttentionSpec(block_size * 2, 1, 1, torch.float16),
|
||||
FullAttentionSpec(
|
||||
block_size=block_size * 2,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float16,
|
||||
),
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
["layer2"],
|
||||
SlidingWindowSpec(
|
||||
block_size,
|
||||
1,
|
||||
1,
|
||||
torch.float32,
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=2 * block_size,
|
||||
),
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user