[V1] [Hybrid] Support using float32 for state in Hybrid Models (Mamba2, Mamba1, Minimax) (#22928)

Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Daniel Afrimi <danielafrimi8@gmail.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Thomas Parnell
2025-08-15 14:57:06 +02:00
committed by GitHub
parent 22341b996e
commit 75531a6c13
23 changed files with 467 additions and 87 deletions

View File

@@ -182,14 +182,15 @@ class SlidingWindowSpec(AttentionSpec):
 @dataclass(frozen=True)
 class MambaSpec(KVCacheSpec):
     shapes: tuple[tuple[int, ...], ...]
-    dtype: torch.dtype
+    dtypes: tuple[torch.dtype]
     page_size_padded: Optional[int] = None
     mamba_type: str = "mamba2"
 
     @property
     def page_size_bytes(self) -> int:
-        num_elements = sum(prod(shape) for shape in self.shapes)
-        page_size = num_elements * get_dtype_size(self.dtype)
+        page_size = sum(
+            prod(shape) * get_dtype_size(dtype)
+            for (shape, dtype) in zip(self.shapes, self.dtypes))
         if self.page_size_padded is not None:
             assert self.page_size_padded >= page_size
             return self.page_size_padded