[Bugfix][V1] Handle MLA in kv_cache_interface (#14462)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
committed by
GitHub
parent
ef64044079
commit
333681408f
@@ -23,9 +23,9 @@ class KVCacheSpecBase:
|
||||
def type_id(self) -> str:
|
||||
"""
|
||||
The type identifier of this KV cache.
|
||||
Return different strings for layers with different KV cache type (e.g.,
|
||||
different number of tokens like full attention vs sliding window
|
||||
attention, different KV cache size per token like layers with different
|
||||
Return different strings for layers with different KV cache type (e.g.,
|
||||
different number of tokens like full attention vs sliding window
|
||||
attention, different KV cache size per token like layers with different
|
||||
number of heads)
|
||||
|
||||
Returns:
|
||||
@@ -59,6 +59,7 @@ class FullAttentionSpec(KVCacheSpecBase):
|
||||
num_kv_heads: int
|
||||
head_size: int
|
||||
dtype: torch.dtype
|
||||
use_mla: bool
|
||||
|
||||
@property
|
||||
def type_id(self) -> str:
|
||||
@@ -66,7 +67,9 @@ class FullAttentionSpec(KVCacheSpecBase):
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
return 2 * self.block_size * self.num_kv_heads * self.head_size \
|
||||
# For MLA we only store a single latent vector per token instead of separate K and V, so the usual factor of 2 becomes 1
|
||||
coef = 1 if self.use_mla else 2
|
||||
return coef * self.block_size * self.num_kv_heads * self.head_size \
|
||||
* get_dtype_size(self.dtype)
|
||||
|
||||
def bytes_for_tokens(self, num_tokens: int) -> int:
|
||||
@@ -104,7 +107,7 @@ class KVCacheConfig:
|
||||
2. (not implemented yet) A model with the same number of full attention
|
||||
layers and sliding window attention layers: two groups, one for full
|
||||
attention layers and one for sliding window attention layers.
|
||||
3. (not implemented yet) A model with 2 full attention layers and 4 sliding
|
||||
3. (not implemented yet) A model with 2 full attention layers and 4 sliding
|
||||
window attention layers: three groups, (full * 2), (sw * 2), (sw * 2).
|
||||
"""
|
||||
groups: list[list[str]]
|
||||
|
||||
Reference in New Issue
Block a user