[Mamba] Add stochastic rounding support (#35753)
Signed-off-by: Roi Koren <roik@nvidia.com>
This commit is contained in:
@@ -428,6 +428,8 @@ class MambaMixer(MambaBase, PluggableLayer):
     state_batch_indices=state_indices_tensor_d_input,
     dst_state_batch_indices=state_indices_tensor_d_output,
     out=scan_outputs_d,
+    enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding,
+    cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds,
 )
 scan_outputs_d = scan_outputs_d.transpose(0, 1)
 
Reference in New Issue
Block a user