Add Bamba Model (#10909)
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
committed by
GitHub
parent
467a96a541
commit
aff404571b
@@ -455,14 +455,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
self.mamba_cache = MambaCacheManager(
|
||||
self.lm_head.weight.dtype, num_mamba_layers,
|
||||
self.max_batch_size, *self._get_mamba_cache_shape())
|
||||
(
|
||||
mamba_cache_tensors,
|
||||
state_indices_tensor,
|
||||
) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
|
||||
**kwargs)
|
||||
mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
|
||||
mamba_cache_tensors[1],
|
||||
state_indices_tensor)
|
||||
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, mamba_cache_params,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
|
||||
Reference in New Issue
Block a user