[Mamba] Add stochastic rounding support (#35753)
Signed-off-by: Roi Koren <roik@nvidia.com>
This commit is contained in:
@@ -888,6 +888,8 @@ class MambaMixer2(MambaBase, PluggableLayer):
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
cu_seqlens=query_start_loc_d,
|
||||
is_blackwell=self.is_blackwell,
|
||||
enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding,
|
||||
cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds,
|
||||
)
|
||||
|
||||
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
|
||||
|
||||
Reference in New Issue
Block a user