[v1] Support mamba2 (#19327)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-06-19 04:34:15 +08:00
committed by GitHub
parent ffacb222cb
commit a89209b78d
9 changed files with 582 additions and 120 deletions

View File

@@ -8,6 +8,7 @@ import torch
from torch import nn
from transformers import MambaConfig
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -25,8 +26,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState,
IsAttentionFree,
SupportsV0Only)
IsAttentionFree)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -44,7 +44,8 @@ class Mamba2DecoderLayer(nn.Module):
def __init__(self,
config: MambaConfig,
quant_config: Optional[QuantizationConfig] = None) -> None:
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__()
self.config = config
self.mixer = MambaMixer2(hidden_size=config.hidden_size,
@@ -60,7 +61,9 @@ class Mamba2DecoderLayer(nn.Module):
head_dim=config.head_dim,
rms_norm_eps=config.layer_norm_epsilon,
activation=config.hidden_act,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.mixer",
chunk_size=config.chunk_size)
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -108,8 +111,8 @@ class Mamba2Model(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Mamba2DecoderLayer(config,
quant_config=quant_config),
lambda prefix: Mamba2DecoderLayer(
config, quant_config=quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm_f = RMSNorm(config.hidden_size,
@@ -142,10 +145,14 @@ class Mamba2Model(nn.Module):
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
for i in range(len(self.layers)):
layer = self.layers[i]
@@ -155,7 +162,7 @@ class Mamba2Model(nn.Module):
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=mamba_cache_params.at_layer_idx(
i - self.start_layer),
i - self.start_layer) if mamba_cache_params else None,
mamba2_metadata=mamba2_metadata)
if not get_pp_group().is_last_rank:
@@ -190,8 +197,7 @@ class Mamba2Model(nn.Module):
return loaded_params
class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
SupportsV0Only):
class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
@@ -242,14 +248,20 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:
num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
*self._get_mamba_cache_shape())
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = (
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba))
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype,
num_mamba_layers, *self._get_mamba_cache_shape())
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
else:
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params = None
hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)