Decouple page_size_bytes calculation in AttentionSpec for TPU/RPA Compatibility. (#31635)

Signed-off-by: Lihao Ran <imlihao.ran@gmail.com>
Signed-off-by: Lumosis <30372757+Lumosis@users.noreply.github.com>
This commit is contained in:
Lumosis
2026-01-08 01:00:07 -08:00
committed by GitHub
parent eac3b96ec0
commit b634e619bb
6 changed files with 75 additions and 20 deletions

View File

@@ -61,14 +61,23 @@ class KVCacheSpec:
return copy.deepcopy(specs[0])
@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class AttentionSpec(KVCacheSpec):
num_kv_heads: int
head_size: int
dtype: torch.dtype
page_size_padded: int | None = None
@property
def page_size_bytes(self) -> int:
real_page_size = self.real_page_size_bytes
if self.page_size_padded is not None:
assert self.page_size_padded >= real_page_size
return self.page_size_padded
return real_page_size
@property
def real_page_size_bytes(self) -> int:
return (
2
* self.block_size
@@ -78,7 +87,7 @@ class AttentionSpec(KVCacheSpec):
)
@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class FullAttentionSpec(AttentionSpec):
"""
When hybrid allocator is disabled and the model contains both full
@@ -150,6 +159,7 @@ class FullAttentionSpec(AttentionSpec):
head_size=specs[0].head_size,
head_size_v=specs[0].head_size_v,
dtype=specs[0].dtype,
page_size_padded=specs[0].page_size_padded,
sliding_window=cls.merge_window_sizes(sliding_window),
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
)
@@ -168,7 +178,7 @@ class FullAttentionSpec(AttentionSpec):
return merged_spec
@property
def page_size_bytes(self) -> int:
def real_page_size_bytes(self) -> int:
return (
self.block_size
* self.num_kv_heads
@@ -177,13 +187,13 @@ class FullAttentionSpec(AttentionSpec):
)
@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class MLAAttentionSpec(FullAttentionSpec):
# TODO(Lucas/Chen): less hacky way to do this
cache_dtype_str: str | None = None
@property
def page_size_bytes(self) -> int:
def real_page_size_bytes(self) -> int:
if self.cache_dtype_str == "fp8_ds_mla":
# See `vllm/v1/attention/backends/mla/flashmla_sparse.py`
# for details.
@@ -210,11 +220,12 @@ class MLAAttentionSpec(FullAttentionSpec):
num_kv_heads=specs[0].num_kv_heads,
head_size=specs[0].head_size,
dtype=specs[0].dtype,
page_size_padded=specs[0].page_size_padded,
cache_dtype_str=cache_dtype_str_set.pop(),
)
@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class ChunkedLocalAttentionSpec(AttentionSpec):
attention_chunk_size: int
@@ -233,7 +244,7 @@ class ChunkedLocalAttentionSpec(AttentionSpec):
return cdiv(num_tokens, self.block_size) * self.page_size_bytes
@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class SlidingWindowSpec(AttentionSpec):
sliding_window: int
@@ -335,6 +346,7 @@ class SinkFullAttentionSpec(FullAttentionSpec):
head_size_v=specs[0].head_size_v,
sink_len=specs[0].sink_len,
dtype=specs[0].dtype,
page_size_padded=specs[0].page_size_padded,
sliding_window=cls.merge_window_sizes(sliding_window),
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
)