[v1][mamba] Added mamba_type into MambaSpec (#21715)

Signed-off-by: asafg <asafg@ai21.com>
Co-authored-by: asafg <asafg@ai21.com>
This commit is contained in:
Asaf Joseph Gardin
2025-07-28 11:15:55 +03:00
committed by GitHub
parent 139a7f07bd
commit a6c050286a
6 changed files with 52 additions and 4 deletions

View File

@@ -200,13 +200,14 @@ class MambaSpec(KVCacheSpec):
shapes: tuple[tuple[int, ...], ...]
dtype: torch.dtype
page_size_padded: Optional[int] = None
mamba_type: str = "mamba2"
def __post_init__(self):
self.num_elements = sum(prod(shape) for shape in self.shapes)
@property
def type_id(self) -> str:
return f"mamba_{self.shapes}_{self.dtype}"
return f"mamba_{self.shapes}_{self.dtype}_{self.mamba_type}"
@property
def page_size_bytes(self) -> int: