[Mamba] Add stochastic rounding support (#35753)

Signed-off-by: Roi Koren <roik@nvidia.com>
This commit is contained in:
roikoren755
2026-03-30 19:33:49 +03:00
committed by GitHub
parent dbdd9ae067
commit 8e6293e838
7 changed files with 166 additions and 2 deletions

View File

@@ -445,6 +445,8 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
dt_softplus=True,
state_batch_indices=state_indices_tensor_d,
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding,
cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds,
)
# 4. Final linear projection