[V1][Mamba1] - FP32 SSM Kernel Support (#23506)

Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
Author: Asaf Joseph Gardin
Date: 2025-09-02 06:53:00 +03:00
Committed by: GitHub
Parent: 0235103cbb
Commit: 2b41cbbf03
3 changed files with 65 additions and 32 deletions

@@ -30,12 +30,8 @@ class MambaStateDtypeCalculator:
         mamba_cache_dtype: MambaDType,
         mamba_ssm_cache_dtype: MambaDType,
     ) -> tuple[torch.dtype, ...]:
-        # TODO (tdoublep) requires kernel changes
-        if mamba_cache_dtype == "float32" or mamba_ssm_cache_dtype == "float32":
-            raise ValueError("fp32 state for mamba1 is not yet supported")
-        else:
-            return MambaStateDtypeCalculator.mamba2_state_dtype(
-                model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype)
+        return cls._mamba_state_dtype(model_dtype, mamba_cache_dtype,
+                                      mamba_ssm_cache_dtype)
 
     @classmethod
     def mamba2_state_dtype(
@@ -43,6 +39,16 @@ class MambaStateDtypeCalculator:
         model_dtype: Union[ModelDType, torch.dtype],
         mamba_cache_dtype: MambaDType,
         mamba_ssm_cache_dtype: MambaDType,
     ) -> tuple[torch.dtype, ...]:
+        return cls._mamba_state_dtype(model_dtype, mamba_cache_dtype,
+                                      mamba_ssm_cache_dtype)
+
+    @classmethod
+    def _mamba_state_dtype(
+        cls,
+        model_dtype: Union[ModelDType, torch.dtype],
+        mamba_cache_dtype: MambaDType,
+        mamba_ssm_cache_dtype: MambaDType,
+    ) -> tuple[torch.dtype, ...]:
         conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
                                                     model_dtype)
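
For readers outside the vLLM tree, here is a minimal, self-contained sketch of what the shared helper appears to compute. Only the conv_state_dtype line mirrors the diff above; the _get_kv_cache_torch_dtype stand-in and the "auto" fallback for the SSM state dtype are illustrative assumptions, not vLLM's actual implementation.

# Self-contained sketch of the shared dtype helper introduced by this commit.
# Only the conv_state_dtype line is taken from the diff; everything else
# (the simplified _get_kv_cache_torch_dtype stand-in and the "auto" fallback
# for the SSM state) is an illustrative assumption, not vLLM's actual code.
import torch

_STR_TO_DTYPE = {
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
}

def _get_kv_cache_torch_dtype(cache_dtype: str,
                              model_dtype: torch.dtype) -> torch.dtype:
    # Stand-in for vLLM's get_kv_cache_torch_dtype: "auto" follows the model.
    return model_dtype if cache_dtype == "auto" else _STR_TO_DTYPE[cache_dtype]

def mamba_state_dtype(model_dtype: torch.dtype,
                      mamba_cache_dtype: str,
                      mamba_ssm_cache_dtype: str) -> tuple[torch.dtype, ...]:
    # Mirrors the diff: the conv state dtype comes from the KV-cache setting.
    conv_state_dtype = _get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
    # Assumed: the SSM state inherits the conv-state dtype unless overridden.
    if mamba_ssm_cache_dtype == "auto":
        ssm_state_dtype = conv_state_dtype
    else:
        ssm_state_dtype = _STR_TO_DTYPE[mamba_ssm_cache_dtype]
    return conv_state_dtype, ssm_state_dtype

# With this commit, mamba1 shares this path and no longer raises for fp32:
assert mamba_state_dtype(torch.bfloat16, "auto", "float32") == (
    torch.bfloat16, torch.float32)

The final assertion reflects the point of this commit: requesting "float32" SSM state for mamba1 now resolves to a dtype tuple through the shared helper instead of raising ValueError.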