From b634e619bbcfed0abe4e01d0e2d97fb1fdfdbbd5 Mon Sep 17 00:00:00 2001 From: Lumosis <30372757+Lumosis@users.noreply.github.com> Date: Thu, 8 Jan 2026 01:00:07 -0800 Subject: [PATCH] Decouple page_size_bytes calculation in AttentionSpec for TPU/RPA Compatibility. (#31635) Signed-off-by: Lihao Ran Signed-off-by: Lumosis <30372757+Lumosis@users.noreply.github.com> --- tests/v1/core/test_kv_sharing.py | 4 ++- tests/v1/core/test_prefix_caching.py | 41 ++++++++++++++++++++++------ tests/v1/core/test_scheduler.py | 8 +++++- tests/v1/core/utils.py | 8 +++++- tests/v1/kv_connector/unit/utils.py | 8 +++++- vllm/v1/kv_cache_interface.py | 26 +++++++++++++----- 6 files changed, 75 insertions(+), 20 deletions(-) diff --git a/tests/v1/core/test_kv_sharing.py b/tests/v1/core/test_kv_sharing.py index e6d37b1d6..db0e8dae3 100644 --- a/tests/v1/core/test_kv_sharing.py +++ b/tests/v1/core/test_kv_sharing.py @@ -11,7 +11,9 @@ pytestmark = pytest.mark.cpu_test def new_kv_cache_spec(): - return FullAttentionSpec(16, 1, 1, torch.float32, False) + return FullAttentionSpec( + block_size=16, num_kv_heads=1, head_size=1, dtype=torch.float32 + ) def test_initialize_kv_cache_for_kv_sharing_different_attn_groups(): diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 977ec71bc..0600b813a 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -94,7 +94,12 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: kv_cache_groups=[ KVCacheGroupSpec( ["layer"], - FullAttentionSpec(block_size, 1, 1, torch.float32), + FullAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ), ) ], ) @@ -109,18 +114,31 @@ def make_kv_cache_config_hybrid_model( kv_cache_groups=[ KVCacheGroupSpec( ["layer1"], - FullAttentionSpec(block_size, 1, 1, torch.float32), + FullAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ), ), KVCacheGroupSpec( ["layer2"], SlidingWindowSpec( - block_size, 1, 1, torch.float32, sliding_window=2 * block_size + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=2 * block_size, ), ), KVCacheGroupSpec( ["layer3"], SlidingWindowSpec( - block_size, 1, 1, torch.float32, sliding_window=2 * block_size + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=2 * block_size, ), ), ], @@ -1616,15 +1634,20 @@ def test_different_block_size(): kv_cache_groups=[ KVCacheGroupSpec( ["layer1"], - FullAttentionSpec(block_size * 2, 1, 1, torch.float16), + FullAttentionSpec( + block_size=block_size * 2, + num_kv_heads=1, + head_size=1, + dtype=torch.float16, + ), ), KVCacheGroupSpec( ["layer2"], SlidingWindowSpec( - block_size, - 1, - 1, - torch.float32, + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, sliding_window=2 * block_size, ), ), diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index b44c2a1c3..9f1b7bf41 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1573,7 +1573,13 @@ def create_scheduler_with_priority( kv_cache_tensors=[], kv_cache_groups=[ KVCacheGroupSpec( - ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False) + ["layer"], + FullAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ), ) ], ) diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index 531b9c595..bcc689070 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -142,7 +142,13 @@ def create_scheduler( kv_cache_tensors=[], kv_cache_groups=[ KVCacheGroupSpec( - ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False) + ["layer"], + FullAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ), ) ], ) diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 3a0dbb8e4..e754a0917 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -148,7 +148,13 @@ def create_scheduler( kv_cache_tensors=[], kv_cache_groups=[ KVCacheGroupSpec( - ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False) + ["layer"], + FullAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ), ) ], ) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 7a7bb9036..5c9913bb0 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -61,14 +61,23 @@ class KVCacheSpec: return copy.deepcopy(specs[0]) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class AttentionSpec(KVCacheSpec): num_kv_heads: int head_size: int dtype: torch.dtype + page_size_padded: int | None = None @property def page_size_bytes(self) -> int: + real_page_size = self.real_page_size_bytes + if self.page_size_padded is not None: + assert self.page_size_padded >= real_page_size + return self.page_size_padded + return real_page_size + + @property + def real_page_size_bytes(self) -> int: return ( 2 * self.block_size @@ -78,7 +87,7 @@ class AttentionSpec(KVCacheSpec): ) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class FullAttentionSpec(AttentionSpec): """ When hybrid allocator is disabled and the model contains both full @@ -150,6 +159,7 @@ class FullAttentionSpec(AttentionSpec): head_size=specs[0].head_size, head_size_v=specs[0].head_size_v, dtype=specs[0].dtype, + page_size_padded=specs[0].page_size_padded, sliding_window=cls.merge_window_sizes(sliding_window), attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), ) @@ -168,7 +178,7 @@ class FullAttentionSpec(AttentionSpec): return merged_spec @property - def page_size_bytes(self) -> int: + def real_page_size_bytes(self) -> int: return ( self.block_size * self.num_kv_heads @@ -177,13 +187,13 @@ class FullAttentionSpec(AttentionSpec): ) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class MLAAttentionSpec(FullAttentionSpec): # TODO(Lucas/Chen): less hacky way to do this cache_dtype_str: str | None = None @property - def page_size_bytes(self) -> int: + def real_page_size_bytes(self) -> int: if self.cache_dtype_str == "fp8_ds_mla": # See `vllm/v1/attention/backends/mla/flashmla_sparse.py` # for details. @@ -210,11 +220,12 @@ class MLAAttentionSpec(FullAttentionSpec): num_kv_heads=specs[0].num_kv_heads, head_size=specs[0].head_size, dtype=specs[0].dtype, + page_size_padded=specs[0].page_size_padded, cache_dtype_str=cache_dtype_str_set.pop(), ) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ChunkedLocalAttentionSpec(AttentionSpec): attention_chunk_size: int @@ -233,7 +244,7 @@ class ChunkedLocalAttentionSpec(AttentionSpec): return cdiv(num_tokens, self.block_size) * self.page_size_bytes -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class SlidingWindowSpec(AttentionSpec): sliding_window: int @@ -335,6 +346,7 @@ class SinkFullAttentionSpec(FullAttentionSpec): head_size_v=specs[0].head_size_v, sink_len=specs[0].sink_len, dtype=specs[0].dtype, + page_size_padded=specs[0].page_size_padded, sliding_window=cls.merge_window_sizes(sliding_window), attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), )