[Model] Support Mamba2 (Codestral Mamba) (#9292)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
This commit is contained in:
committed by
GitHub
parent
7b623fca0b
commit
1f69c4a892
@@ -166,14 +166,13 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
lora_config = vllm_config.lora_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
assert not cache_config.enable_prefix_caching, \
|
||||
"Mamba does not support prefix caching"
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vllm_config = vllm_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.backbone = MambaModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "backbone"))
|
||||
@@ -202,17 +201,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.backbone.make_empty_intermediate_tensors)
|
||||
if self.scheduler_config is not None and \
|
||||
not self.model_config.enforce_eager:
|
||||
if self.scheduler_config.max_num_seqs > \
|
||||
vllm_config.compilation_config.max_capture_size:
|
||||
self.max_batch_size = \
|
||||
vllm_config.compilation_config.max_capture_size
|
||||
else:
|
||||
self.max_batch_size = vllm_config.pad_for_cudagraph(
|
||||
self.scheduler_config.max_num_seqs)
|
||||
else:
|
||||
self.max_batch_size = 8192 + 2
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.backbone.get_input_embeddings(input_ids)
|
||||
@@ -229,8 +217,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
|
||||
num_mamba_layers = self.model_config.get_num_layers_by_block_type(
|
||||
self.vllm_config.parallel_config, LayerBlockType.mamba)
|
||||
self.mamba_cache = MambaCacheManager(
|
||||
self.lm_head.weight.dtype, num_mamba_layers,
|
||||
self.max_batch_size, *self._get_mamba_cache_shape())
|
||||
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
|
||||
*self._get_mamba_cache_shape())
|
||||
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user