[Bugfix][V1] Handle MLA in kv_cache_interface (#14462)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

Author:    Tyler Michael Smith
Date:      2025-03-08 01:18:25 -05:00
Committed: GitHub
Parent:    ef64044079
Commit:    333681408f

3 changed files with 15 additions and 10 deletions

vllm/v1/kv_cache_interface.py

@@ -23,9 +23,9 @@ class KVCacheSpecBase:
     def type_id(self) -> str:
         """
         The type identifier of this KV cache.
         Return different strings for layers with different KV cache types
         (e.g., a different number of tokens, like full attention vs. sliding
         window attention, or a different KV cache size per token, like
         layers with a different number of heads).

         Returns:
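For reference, a minimal runnable sketch of how a concrete spec might satisfy this contract. `_ExampleSpec` and its string format are illustrative assumptions, not taken from this diff:

```python
from dataclasses import dataclass


@dataclass
class _ExampleSpec:
    # Hypothetical spec used only to illustrate the type_id contract.
    block_size: int
    page_size_bytes: int

    @property
    def type_id(self) -> str:
        # Two layers may share KV cache pages only if their type_ids match,
        # so the id encodes every property that affects the cache layout.
        return f"full_attention_{self.block_size}_{self.page_size_bytes}"


assert _ExampleSpec(16, 65536).type_id == "full_attention_16_65536"
```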
@@ -59,6 +59,7 @@ class FullAttentionSpec(KVCacheSpecBase):
     num_kv_heads: int
     head_size: int
     dtype: torch.dtype
+    use_mla: bool

     @property
     def type_id(self) -> str:
@@ -66,7 +67,9 @@ class FullAttentionSpec(KVCacheSpecBase):
     @property
     def page_size_bytes(self) -> int:
-        return 2 * self.block_size * self.num_kv_heads * self.head_size \
-            * get_dtype_size(self.dtype)
+        # For MLA we only store a single latent vector
+        coef = 1 if self.use_mla else 2
+        return coef * self.block_size * self.num_kv_heads * self.head_size \
+            * get_dtype_size(self.dtype)

     def bytes_for_tokens(self, num_tokens: int) -> int:
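As a sanity check on the new arithmetic, a standalone sketch. The dimensions are illustrative, loosely modeled on an MLA layer with a single latent head, and `get_dtype_size` is re-derived here rather than imported from vLLM:

```python
import torch


def get_dtype_size(dtype: torch.dtype) -> int:
    # Size in bytes of one element of the given dtype.
    return torch.tensor([], dtype=dtype).element_size()


def page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                    dtype: torch.dtype, use_mla: bool) -> int:
    # MLA caches one latent vector per token instead of separate K and V
    # tensors, so the usual factor of 2 (K plus V) drops to 1.
    coef = 1 if use_mla else 2
    return (coef * block_size * num_kv_heads * head_size
            * get_dtype_size(dtype))


# Standard attention: K and V are both stored, hence the factor of 2.
print(page_size_bytes(16, 8, 128, torch.bfloat16, use_mla=False))  # 65536
# MLA: one latent vector per token, so coef drops to 1.
print(page_size_bytes(16, 1, 576, torch.bfloat16, use_mla=True))   # 18432
```

Before this fix, the hard-coded factor of 2 doubled the reported page size for MLA layers, overstating the memory their KV cache requires.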
@@ -104,7 +107,7 @@ class KVCacheConfig:
     2. (not implemented yet) A model with the same number of full attention
        layers and sliding window attention layers: two groups, one for full
        attention layers and one for sliding window attention layers.
     3. (not implemented yet) A model with 2 full attention layers and 4 sliding
        window attention layers: three groups, (full * 2), (sw * 2), (sw * 2).
     """
     groups: list[list[str]]
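For the third case in the docstring, a minimal sketch of what `groups` would contain; the layer names are hypothetical:

```python
# 2 full attention layers and 4 sliding window layers split into three
# groups: (full * 2), (sw * 2), (sw * 2).
groups: list[list[str]] = [
    ["layers.0.attn", "layers.1.attn"],  # full attention
    ["layers.2.attn", "layers.3.attn"],  # sliding window
    ["layers.4.attn", "layers.5.attn"],  # sliding window
]
```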