[v1][mamba] Added mamba_type into MambaSpec (#21715)
Signed-off-by: asafg <asafg@ai21.com> Co-authored-by: asafg <asafg@ai21.com>
This commit is contained in:
committed by
GitHub
parent
139a7f07bd
commit
a6c050286a
@@ -200,13 +200,14 @@ class MambaSpec(KVCacheSpec):
|
||||
shapes: tuple[tuple[int, ...], ...]
|
||||
dtype: torch.dtype
|
||||
page_size_padded: Optional[int] = None
|
||||
mamba_type: str = "mamba2"
|
||||
|
||||
def __post_init__(self):
|
||||
self.num_elements = sum(prod(shape) for shape in self.shapes)
|
||||
|
||||
@property
|
||||
def type_id(self) -> str:
|
||||
return f"mamba_{self.shapes}_{self.dtype}"
|
||||
return f"mamba_{self.shapes}_{self.dtype}_{self.mamba_type}"
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
|
||||
Reference in New Issue
Block a user