Add Bamba Model (#10909)

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
committed by GitHub
parent 467a96a541
commit aff404571b
@@ -232,15 +232,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
|
||||
self.lm_head.weight.dtype, num_mamba_layers,
|
||||
self.max_batch_size, *self._get_mamba_cache_shape())
|
||||
|
||||
(
|
||||
mamba_cache_tensors,
|
||||
state_indices_tensor,
|
||||
) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
|
||||
**kwargs)
|
||||
|
||||
mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
|
||||
mamba_cache_tensors[1],
|
||||
state_indices_tensor)
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
|
||||
hidden_states = self.backbone(input_ids, positions, attn_metadata,
|
||||
mamba_cache_params, intermediate_tensors,
|
||||
|
||||
Reference in New Issue
Block a user