[V1] [Hybrid] Support using float32 for state in Hybrid Models (Mamba2, Mamba1, Minimax) (#22928)
Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Daniel Afrimi <danielafrimi8@gmail.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -182,14 +182,15 @@ class SlidingWindowSpec(AttentionSpec):
|
||||
@dataclass(frozen=True)
|
||||
class MambaSpec(KVCacheSpec):
|
||||
shapes: tuple[tuple[int, ...], ...]
|
||||
dtype: torch.dtype
|
||||
dtypes: tuple[torch.dtype]
|
||||
page_size_padded: Optional[int] = None
|
||||
mamba_type: str = "mamba2"
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
num_elements = sum(prod(shape) for shape in self.shapes)
|
||||
page_size = num_elements * get_dtype_size(self.dtype)
|
||||
page_size = sum(
|
||||
prod(shape) * get_dtype_size(dtype)
|
||||
for (shape, dtype) in zip(self.shapes, self.dtypes))
|
||||
if self.page_size_padded is not None:
|
||||
assert self.page_size_padded >= page_size
|
||||
return self.page_size_padded
|
||||
|
||||
Reference in New Issue
Block a user