[Mamba] Add stochastic rounding support (#35753)

Signed-off-by: Roi Koren <roik@nvidia.com>
This commit is contained in:
roikoren755
2026-03-30 19:33:49 +03:00
committed by GitHub
parent dbdd9ae067
commit 8e6293e838
7 changed files with 166 additions and 2 deletions

View File

@@ -428,6 +428,8 @@ class MambaMixer(MambaBase, PluggableLayer):
state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output,
out=scan_outputs_d,
enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding,
cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds,
)
scan_outputs_d = scan_outputs_d.transpose(0, 1)