Decouple page_size_bytes calculation in AttentionSpec for TPU/RPA Compatibility. (#31635)
Signed-off-by: Lihao Ran <imlihao.ran@gmail.com>
Signed-off-by: Lumosis <30372757+Lumosis@users.noreply.github.com>
This commit is contained in:
@@ -11,7 +11,9 @@ pytestmark = pytest.mark.cpu_test
|
|||||||
|
|
||||||
|
|
||||||
def new_kv_cache_spec():
|
def new_kv_cache_spec():
|
||||||
return FullAttentionSpec(16, 1, 1, torch.float32, False)
|
return FullAttentionSpec(
|
||||||
|
block_size=16, num_kv_heads=1, head_size=1, dtype=torch.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_initialize_kv_cache_for_kv_sharing_different_attn_groups():
|
def test_initialize_kv_cache_for_kv_sharing_different_attn_groups():
|
||||||
|
|||||||
@@ -94,7 +94,12 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
|
|||||||
kv_cache_groups=[
|
kv_cache_groups=[
|
||||||
KVCacheGroupSpec(
|
KVCacheGroupSpec(
|
||||||
["layer"],
|
["layer"],
|
||||||
FullAttentionSpec(block_size, 1, 1, torch.float32),
|
FullAttentionSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float32,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@@ -109,18 +114,31 @@ def make_kv_cache_config_hybrid_model(
|
|||||||
kv_cache_groups=[
|
kv_cache_groups=[
|
||||||
KVCacheGroupSpec(
|
KVCacheGroupSpec(
|
||||||
["layer1"],
|
["layer1"],
|
||||||
FullAttentionSpec(block_size, 1, 1, torch.float32),
|
FullAttentionSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float32,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
KVCacheGroupSpec(
|
KVCacheGroupSpec(
|
||||||
["layer2"],
|
["layer2"],
|
||||||
SlidingWindowSpec(
|
SlidingWindowSpec(
|
||||||
block_size, 1, 1, torch.float32, sliding_window=2 * block_size
|
block_size=block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float32,
|
||||||
|
sliding_window=2 * block_size,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
KVCacheGroupSpec(
|
KVCacheGroupSpec(
|
||||||
["layer3"],
|
["layer3"],
|
||||||
SlidingWindowSpec(
|
SlidingWindowSpec(
|
||||||
block_size, 1, 1, torch.float32, sliding_window=2 * block_size
|
block_size=block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float32,
|
||||||
|
sliding_window=2 * block_size,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@@ -1616,15 +1634,20 @@ def test_different_block_size():
|
|||||||
kv_cache_groups=[
|
kv_cache_groups=[
|
||||||
KVCacheGroupSpec(
|
KVCacheGroupSpec(
|
||||||
["layer1"],
|
["layer1"],
|
||||||
FullAttentionSpec(block_size * 2, 1, 1, torch.float16),
|
FullAttentionSpec(
|
||||||
|
block_size=block_size * 2,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float16,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
KVCacheGroupSpec(
|
KVCacheGroupSpec(
|
||||||
["layer2"],
|
["layer2"],
|
||||||
SlidingWindowSpec(
|
SlidingWindowSpec(
|
||||||
block_size,
|
block_size=block_size,
|
||||||
1,
|
num_kv_heads=1,
|
||||||
1,
|
head_size=1,
|
||||||
torch.float32,
|
dtype=torch.float32,
|
||||||
sliding_window=2 * block_size,
|
sliding_window=2 * block_size,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -1573,7 +1573,13 @@ def create_scheduler_with_priority(
|
|||||||
kv_cache_tensors=[],
|
kv_cache_tensors=[],
|
||||||
kv_cache_groups=[
|
kv_cache_groups=[
|
||||||
KVCacheGroupSpec(
|
KVCacheGroupSpec(
|
||||||
["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
|
["layer"],
|
||||||
|
FullAttentionSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float32,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -142,7 +142,13 @@ def create_scheduler(
|
|||||||
kv_cache_tensors=[],
|
kv_cache_tensors=[],
|
||||||
kv_cache_groups=[
|
kv_cache_groups=[
|
||||||
KVCacheGroupSpec(
|
KVCacheGroupSpec(
|
||||||
["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
|
["layer"],
|
||||||
|
FullAttentionSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float32,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -148,7 +148,13 @@ def create_scheduler(
|
|||||||
kv_cache_tensors=[],
|
kv_cache_tensors=[],
|
||||||
kv_cache_groups=[
|
kv_cache_groups=[
|
||||||
KVCacheGroupSpec(
|
KVCacheGroupSpec(
|
||||||
["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
|
["layer"],
|
||||||
|
FullAttentionSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float32,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -61,14 +61,23 @@ class KVCacheSpec:
|
|||||||
return copy.deepcopy(specs[0])
|
return copy.deepcopy(specs[0])
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class AttentionSpec(KVCacheSpec):
|
class AttentionSpec(KVCacheSpec):
|
||||||
num_kv_heads: int
|
num_kv_heads: int
|
||||||
head_size: int
|
head_size: int
|
||||||
dtype: torch.dtype
|
dtype: torch.dtype
|
||||||
|
page_size_padded: int | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def page_size_bytes(self) -> int:
|
def page_size_bytes(self) -> int:
|
||||||
|
real_page_size = self.real_page_size_bytes
|
||||||
|
if self.page_size_padded is not None:
|
||||||
|
assert self.page_size_padded >= real_page_size
|
||||||
|
return self.page_size_padded
|
||||||
|
return real_page_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def real_page_size_bytes(self) -> int:
|
||||||
return (
|
return (
|
||||||
2
|
2
|
||||||
* self.block_size
|
* self.block_size
|
||||||
@@ -78,7 +87,7 @@ class AttentionSpec(KVCacheSpec):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class FullAttentionSpec(AttentionSpec):
|
class FullAttentionSpec(AttentionSpec):
|
||||||
"""
|
"""
|
||||||
When hybrid allocator is disabled and the model contains both full
|
When hybrid allocator is disabled and the model contains both full
|
||||||
@@ -150,6 +159,7 @@ class FullAttentionSpec(AttentionSpec):
|
|||||||
head_size=specs[0].head_size,
|
head_size=specs[0].head_size,
|
||||||
head_size_v=specs[0].head_size_v,
|
head_size_v=specs[0].head_size_v,
|
||||||
dtype=specs[0].dtype,
|
dtype=specs[0].dtype,
|
||||||
|
page_size_padded=specs[0].page_size_padded,
|
||||||
sliding_window=cls.merge_window_sizes(sliding_window),
|
sliding_window=cls.merge_window_sizes(sliding_window),
|
||||||
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
|
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
|
||||||
)
|
)
|
||||||
@@ -168,7 +178,7 @@ class FullAttentionSpec(AttentionSpec):
|
|||||||
return merged_spec
|
return merged_spec
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def page_size_bytes(self) -> int:
|
def real_page_size_bytes(self) -> int:
|
||||||
return (
|
return (
|
||||||
self.block_size
|
self.block_size
|
||||||
* self.num_kv_heads
|
* self.num_kv_heads
|
||||||
@@ -177,13 +187,13 @@ class FullAttentionSpec(AttentionSpec):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class MLAAttentionSpec(FullAttentionSpec):
|
class MLAAttentionSpec(FullAttentionSpec):
|
||||||
# TODO(Lucas/Chen): less hacky way to do this
|
# TODO(Lucas/Chen): less hacky way to do this
|
||||||
cache_dtype_str: str | None = None
|
cache_dtype_str: str | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def page_size_bytes(self) -> int:
|
def real_page_size_bytes(self) -> int:
|
||||||
if self.cache_dtype_str == "fp8_ds_mla":
|
if self.cache_dtype_str == "fp8_ds_mla":
|
||||||
# See `vllm/v1/attention/backends/mla/flashmla_sparse.py`
|
# See `vllm/v1/attention/backends/mla/flashmla_sparse.py`
|
||||||
# for details.
|
# for details.
|
||||||
@@ -210,11 +220,12 @@ class MLAAttentionSpec(FullAttentionSpec):
|
|||||||
num_kv_heads=specs[0].num_kv_heads,
|
num_kv_heads=specs[0].num_kv_heads,
|
||||||
head_size=specs[0].head_size,
|
head_size=specs[0].head_size,
|
||||||
dtype=specs[0].dtype,
|
dtype=specs[0].dtype,
|
||||||
|
page_size_padded=specs[0].page_size_padded,
|
||||||
cache_dtype_str=cache_dtype_str_set.pop(),
|
cache_dtype_str=cache_dtype_str_set.pop(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class ChunkedLocalAttentionSpec(AttentionSpec):
|
class ChunkedLocalAttentionSpec(AttentionSpec):
|
||||||
attention_chunk_size: int
|
attention_chunk_size: int
|
||||||
|
|
||||||
@@ -233,7 +244,7 @@ class ChunkedLocalAttentionSpec(AttentionSpec):
|
|||||||
return cdiv(num_tokens, self.block_size) * self.page_size_bytes
|
return cdiv(num_tokens, self.block_size) * self.page_size_bytes
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True, kw_only=True)
|
||||||
class SlidingWindowSpec(AttentionSpec):
|
class SlidingWindowSpec(AttentionSpec):
|
||||||
sliding_window: int
|
sliding_window: int
|
||||||
|
|
||||||
@@ -335,6 +346,7 @@ class SinkFullAttentionSpec(FullAttentionSpec):
|
|||||||
head_size_v=specs[0].head_size_v,
|
head_size_v=specs[0].head_size_v,
|
||||||
sink_len=specs[0].sink_len,
|
sink_len=specs[0].sink_len,
|
||||||
dtype=specs[0].dtype,
|
dtype=specs[0].dtype,
|
||||||
|
page_size_padded=specs[0].page_size_padded,
|
||||||
sliding_window=cls.merge_window_sizes(sliding_window),
|
sliding_window=cls.merge_window_sizes(sliding_window),
|
||||||
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
|
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user