[Model] Support Mamba2 (Codestral Mamba) (#9292)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
parent 7b623fca0b
commit 1f69c4a892
@@ -440,23 +440,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
-        # follow jamba
-        if self.scheduler_config is not None and \
-            not self.model_config.enforce_eager:
-            # for compilation
-            if self.scheduler_config.max_num_seqs > \
-                vllm_config.compilation_config.max_capture_size:
-                self.max_batch_size = \
-                    vllm_config.compilation_config.max_capture_size
-            else:
-                self.max_batch_size = vllm_config.pad_for_cudagraph(
-                    self.scheduler_config.max_num_seqs)
-        elif self.scheduler_config is not None:
-            # for eager just take the scheduler_config if avail
-            self.max_batch_size = self.scheduler_config.max_num_seqs
-        else:
-            self.max_batch_size = 8192 + 2
-
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
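Note: the block deleted above chose the batch size used to size the Mamba state cache, either from the CUDA graph capture limits or from the scheduler's max_num_seqs. Below is a minimal standalone sketch of that same selection logic, kept only for reference; the function name is hypothetical and the config arguments are duck-typed, this is not part of the commit.

from typing import Optional


def select_mamba_max_batch_size(vllm_config,
                                scheduler_config: Optional[object],
                                model_config) -> int:
    """Pick the batch size used to size the Mamba state cache (sketch)."""
    if scheduler_config is not None and not model_config.enforce_eager:
        # CUDA graph path: cap at the largest captured graph size, otherwise
        # pad the scheduler's max_num_seqs up to a capture size.
        if scheduler_config.max_num_seqs > \
                vllm_config.compilation_config.max_capture_size:
            return vllm_config.compilation_config.max_capture_size
        return vllm_config.pad_for_cudagraph(scheduler_config.max_num_seqs)
    if scheduler_config is not None:
        # Eager mode: just take the scheduler's max_num_seqs.
        return scheduler_config.max_num_seqs
    # No scheduler config available: fall back to a fixed size.
    return 8192 + 2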
@@ -474,8 +457,8 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
 
             self.mamba_cache = MambaCacheManager(
-                self.lm_head.weight.dtype, num_mamba_layers,
-                self.max_batch_size, *self._get_mamba_cache_shape())
+                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
+                *self._get_mamba_cache_shape())
 
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
 
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata, mamba_cache_params,
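Note: the new call passes vllm_config into MambaCacheManager, so the manager can derive its own batch size instead of the model tracking self.max_batch_size. The following is a hypothetical sketch, not vLLM's actual MambaCacheManager, assuming the batch size is picked roughly as in the selection sketch above and that _get_mamba_cache_shape() yields a conv-state shape and an SSM-state shape; buffer layout and names are illustrative only.

import torch


class MambaCacheManagerSketch:
    """Illustrative only: per-layer Mamba state buffers sized from vllm_config."""

    def __init__(self, vllm_config, dtype, num_mamba_layers,
                 conv_state_shape, ssm_state_shape):
        scheduler_config = vllm_config.scheduler_config
        if scheduler_config is not None:
            # Pad to a CUDA graph capture size (the eager-mode branch is
            # omitted here; see the selection sketch above for the full logic).
            max_batch_size = vllm_config.pad_for_cudagraph(
                scheduler_config.max_num_seqs)
        else:
            max_batch_size = 8192 + 2  # fallback used by the removed code
        # One conv state and one SSM state buffer per Mamba layer.
        self.conv_state = torch.zeros(num_mamba_layers, max_batch_size,
                                      *conv_state_shape, dtype=dtype)
        self.ssm_state = torch.zeros(num_mamba_layers, max_batch_size,
                                     *ssm_state_shape, dtype=dtype)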