[Mamba] Add stochastic rounding support (#35753)

Signed-off-by: Roi Koren <roik@nvidia.com>
This commit is contained in:
roikoren755
2026-03-30 19:33:49 +03:00
committed by GitHub
parent dbdd9ae067
commit 8e6293e838
7 changed files with 166 additions and 2 deletions

View File

@@ -888,6 +888,8 @@ class MambaMixer2(MambaBase, PluggableLayer):
num_accepted_tokens=num_accepted_tokens,
cu_seqlens=query_start_loc_d,
is_blackwell=self.is_blackwell,
enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding,
cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds,
)
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: