Decouple page_size_bytes calculation in AttentionSpec for TPU/RPA Compatibility. (#31635)
Signed-off-by: Lihao Ran <imlihao.ran@gmail.com> Signed-off-by: Lumosis <30372757+Lumosis@users.noreply.github.com>
This commit is contained in:
@@ -11,7 +11,9 @@ pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
def new_kv_cache_spec():
|
||||
return FullAttentionSpec(16, 1, 1, torch.float32, False)
|
||||
return FullAttentionSpec(
|
||||
block_size=16, num_kv_heads=1, head_size=1, dtype=torch.float32
|
||||
)
|
||||
|
||||
|
||||
def test_initialize_kv_cache_for_kv_sharing_different_attn_groups():
|
||||
|
||||
@@ -94,7 +94,12 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer"],
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32),
|
||||
FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -109,18 +114,31 @@ def make_kv_cache_config_hybrid_model(
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer1"],
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32),
|
||||
FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
["layer2"],
|
||||
SlidingWindowSpec(
|
||||
block_size, 1, 1, torch.float32, sliding_window=2 * block_size
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=2 * block_size,
|
||||
),
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
["layer3"],
|
||||
SlidingWindowSpec(
|
||||
block_size, 1, 1, torch.float32, sliding_window=2 * block_size
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=2 * block_size,
|
||||
),
|
||||
),
|
||||
],
|
||||
@@ -1616,15 +1634,20 @@ def test_different_block_size():
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer1"],
|
||||
FullAttentionSpec(block_size * 2, 1, 1, torch.float16),
|
||||
FullAttentionSpec(
|
||||
block_size=block_size * 2,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float16,
|
||||
),
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
["layer2"],
|
||||
SlidingWindowSpec(
|
||||
block_size,
|
||||
1,
|
||||
1,
|
||||
torch.float32,
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=2 * block_size,
|
||||
),
|
||||
),
|
||||
|
||||
@@ -1573,7 +1573,13 @@ def create_scheduler_with_priority(
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
|
||||
["layer"],
|
||||
FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@@ -142,7 +142,13 @@ def create_scheduler(
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
|
||||
["layer"],
|
||||
FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@@ -148,7 +148,13 @@ def create_scheduler(
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
|
||||
["layer"],
|
||||
FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@@ -61,14 +61,23 @@ class KVCacheSpec:
|
||||
return copy.deepcopy(specs[0])
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class AttentionSpec(KVCacheSpec):
|
||||
num_kv_heads: int
|
||||
head_size: int
|
||||
dtype: torch.dtype
|
||||
page_size_padded: int | None = None
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
real_page_size = self.real_page_size_bytes
|
||||
if self.page_size_padded is not None:
|
||||
assert self.page_size_padded >= real_page_size
|
||||
return self.page_size_padded
|
||||
return real_page_size
|
||||
|
||||
@property
|
||||
def real_page_size_bytes(self) -> int:
|
||||
return (
|
||||
2
|
||||
* self.block_size
|
||||
@@ -78,7 +87,7 @@ class AttentionSpec(KVCacheSpec):
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class FullAttentionSpec(AttentionSpec):
|
||||
"""
|
||||
When hybrid allocator is disabled and the model contains both full
|
||||
@@ -150,6 +159,7 @@ class FullAttentionSpec(AttentionSpec):
|
||||
head_size=specs[0].head_size,
|
||||
head_size_v=specs[0].head_size_v,
|
||||
dtype=specs[0].dtype,
|
||||
page_size_padded=specs[0].page_size_padded,
|
||||
sliding_window=cls.merge_window_sizes(sliding_window),
|
||||
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
|
||||
)
|
||||
@@ -168,7 +178,7 @@ class FullAttentionSpec(AttentionSpec):
|
||||
return merged_spec
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
def real_page_size_bytes(self) -> int:
|
||||
return (
|
||||
self.block_size
|
||||
* self.num_kv_heads
|
||||
@@ -177,13 +187,13 @@ class FullAttentionSpec(AttentionSpec):
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class MLAAttentionSpec(FullAttentionSpec):
|
||||
# TODO(Lucas/Chen): less hacky way to do this
|
||||
cache_dtype_str: str | None = None
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
def real_page_size_bytes(self) -> int:
|
||||
if self.cache_dtype_str == "fp8_ds_mla":
|
||||
# See `vllm/v1/attention/backends/mla/flashmla_sparse.py`
|
||||
# for details.
|
||||
@@ -210,11 +220,12 @@ class MLAAttentionSpec(FullAttentionSpec):
|
||||
num_kv_heads=specs[0].num_kv_heads,
|
||||
head_size=specs[0].head_size,
|
||||
dtype=specs[0].dtype,
|
||||
page_size_padded=specs[0].page_size_padded,
|
||||
cache_dtype_str=cache_dtype_str_set.pop(),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class ChunkedLocalAttentionSpec(AttentionSpec):
|
||||
attention_chunk_size: int
|
||||
|
||||
@@ -233,7 +244,7 @@ class ChunkedLocalAttentionSpec(AttentionSpec):
|
||||
return cdiv(num_tokens, self.block_size) * self.page_size_bytes
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class SlidingWindowSpec(AttentionSpec):
|
||||
sliding_window: int
|
||||
|
||||
@@ -335,6 +346,7 @@ class SinkFullAttentionSpec(FullAttentionSpec):
|
||||
head_size_v=specs[0].head_size_v,
|
||||
sink_len=specs[0].sink_len,
|
||||
dtype=specs[0].dtype,
|
||||
page_size_padded=specs[0].page_size_padded,
|
||||
sliding_window=cls.merge_window_sizes(sliding_window),
|
||||
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user