[V1][Mamba1] - FP32 SSM Kernel Support (#23506)
Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
This commit is contained in:
committed by GitHub
parent 0235103cbb
commit 2b41cbbf03
@@ -30,12 +30,8 @@ class MambaStateDtypeCalculator:
|
||||
mamba_cache_dtype: MambaDType,
|
||||
mamba_ssm_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
# TODO (tdoublep) requires kernel changes
|
||||
if mamba_cache_dtype == "float32" or mamba_ssm_cache_dtype == "float32":
|
||||
raise ValueError("fp32 state for mamba1 is not yet supported")
|
||||
else:
|
||||
return MambaStateDtypeCalculator.mamba2_state_dtype(
|
||||
model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype)
|
||||
return cls._mamba_state_dtype(model_dtype, mamba_cache_dtype,
|
||||
mamba_ssm_cache_dtype)
|
||||
|
||||
@classmethod
|
||||
def mamba2_state_dtype(
|
||||
@@ -43,6 +39,16 @@ class MambaStateDtypeCalculator:
|
||||
model_dtype: Union[ModelDType, torch.dtype],
|
||||
mamba_cache_dtype: MambaDType,
|
||||
mamba_ssm_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
return cls._mamba_state_dtype(model_dtype, mamba_cache_dtype,
|
||||
mamba_ssm_cache_dtype)
|
||||
|
||||
@classmethod
|
||||
def _mamba_state_dtype(
|
||||
cls,
|
||||
model_dtype: Union[ModelDType, torch.dtype],
|
||||
mamba_cache_dtype: MambaDType,
|
||||
mamba_ssm_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
|
||||
model_dtype)
|
||||
|
||||
Reference in New Issue
Block a user