[Model] Support Mamba2 (Codestral Mamba) (#9292)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
This commit is contained in:
Tyler Michael Smith
2025-02-17 07:17:50 -05:00
committed by GitHub
parent 7b623fca0b
commit 1f69c4a892
9 changed files with 376 additions and 65 deletions

View File

@@ -166,14 +166,13 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
self.scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \
"Mamba does not support prefix caching"
super().__init__()
self.config = config
self.vllm_config = vllm_config
self.scheduler_config = scheduler_config
self.model_config = vllm_config.model_config
self.backbone = MambaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "backbone"))
@@ -202,17 +201,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
self.make_empty_intermediate_tensors = (
self.backbone.make_empty_intermediate_tensors)
if self.scheduler_config is not None and \
not self.model_config.enforce_eager:
if self.scheduler_config.max_num_seqs > \
vllm_config.compilation_config.max_capture_size:
self.max_batch_size = \
vllm_config.compilation_config.max_capture_size
else:
self.max_batch_size = vllm_config.pad_for_cudagraph(
self.scheduler_config.max_num_seqs)
else:
self.max_batch_size = 8192 + 2
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.backbone.get_input_embeddings(input_ids)
@@ -229,8 +217,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager(
self.lm_head.weight.dtype, num_mamba_layers,
self.max_batch_size, *self._get_mamba_cache_shape())
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
*self._get_mamba_cache_shape())
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)