[v1] Support mamba2 (#19327)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from math import prod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@@ -154,6 +155,29 @@ class SlidingWindowSpec(AttentionSpec):
|
||||
return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
|
||||
|
||||
|
||||
@dataclass
class MambaSpec(KVCacheSpec):
    """Cache spec for Mamba-style layers, whose state is a fixed set of tensors.

    Attributes:
        shapes: the shape of each per-request state tensor.
        dtype: element dtype shared by all of the state tensors.
    """

    shapes: tuple[tuple[int, ...], ...]
    dtype: torch.dtype

    def __post_init__(self):
        # Cache the total element count across all state tensors once,
        # since the shapes are fixed after construction.
        total = 0
        for shape in self.shapes:
            total += prod(shape)
        self.num_elements = total

    @property
    def type_id(self) -> str:
        # Identifier that distinguishes caches with different state
        # shapes or dtypes so they are never shared.
        return f"mamba_{self.shapes}_{self.dtype}"

    @property
    def page_size_bytes(self) -> int:
        # A single page holds the complete state for one request.
        return self.num_elements * get_dtype_size(self.dtype)

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        # We allocate 1 block for each request now, so max_memory_usage_bytes is
        # the same as page_size_bytes.
        # Need to update this when supporting prefix caching.
        return self.page_size_bytes
|
||||
|
||||
|
||||
@dataclass
|
||||
class KVCacheTensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user