Add Bamba Model (#10909)

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Yu Chin Fabian Lim
2025-02-07 07:22:42 +08:00
committed by GitHub
parent 467a96a541
commit aff404571b
17 changed files with 3706 additions and 112 deletions

View File

@@ -232,15 +232,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
self.lm_head.weight.dtype, num_mamba_layers,
self.max_batch_size, *self._get_mamba_cache_shape())
(
mamba_cache_tensors,
state_indices_tensor,
) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
**kwargs)
mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
mamba_cache_tensors[1],
state_indices_tensor)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.backbone(input_ids, positions, attn_metadata,
mamba_cache_params, intermediate_tensors,