[v1] Support mamba2 (#19327)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
@@ -8,6 +8,7 @@ import torch
 from torch import nn
 from transformers import MambaConfig
 
+from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -25,8 +26,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import (HasInnerState,
-                                                   IsAttentionFree,
-                                                   SupportsV0Only)
+                                                   IsAttentionFree)
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -44,7 +44,8 @@ class Mamba2DecoderLayer(nn.Module):
 
     def __init__(self,
                  config: MambaConfig,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
         super().__init__()
         self.config = config
         self.mixer = MambaMixer2(hidden_size=config.hidden_size,
@@ -60,7 +61,9 @@ class Mamba2DecoderLayer(nn.Module):
                                  head_dim=config.head_dim,
                                  rms_norm_eps=config.layer_norm_epsilon,
                                  activation=config.hidden_act,
-                                 quant_config=quant_config)
+                                 quant_config=quant_config,
+                                 prefix=f"{prefix}.mixer",
+                                 chunk_size=config.chunk_size)
 
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
@@ -108,8 +111,8 @@ class Mamba2Model(nn.Module):
 
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: Mamba2DecoderLayer(config,
-                                              quant_config=quant_config),
+            lambda prefix: Mamba2DecoderLayer(
+                config, quant_config=quant_config, prefix=prefix),
             prefix=f"{prefix}.layers")
 
         self.norm_f = RMSNorm(config.hidden_size,
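Side note on the hunk above: the point of threading prefix into Mamba2DecoderLayer is that make_layers hands every layer factory a unique dotted name, which the layer then extends for its submodules (here to "{prefix}.mixer"), so per-module quantization config and weight loading can address each MambaMixer2 individually. Below is a minimal self-contained sketch of that contract; make_layers_sketch and DecoderLayerSketch are illustrative stand-ins, not vLLM code.

# Sketch of the prefix contract; names are illustrative, not vLLM internals.
from typing import Callable, List


class DecoderLayerSketch:

    def __init__(self, prefix: str = "") -> None:
        # Mirrors Mamba2DecoderLayer extending its prefix for the mixer submodule.
        self.mixer_prefix = f"{prefix}.mixer"


def make_layers_sketch(num_layers: int,
                       layer_fn: Callable[[str], DecoderLayerSketch],
                       prefix: str) -> List[DecoderLayerSketch]:
    # Each factory call receives a unique dotted name, e.g. "backbone.layers.0".
    return [layer_fn(f"{prefix}.{i}") for i in range(num_layers)]


layers = make_layers_sketch(2, lambda p: DecoderLayerSketch(prefix=p),
                            prefix="backbone.layers")
print([layer.mixer_prefix for layer in layers])
# ['backbone.layers.0.mixer', 'backbone.layers.1.mixer']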
@@ -142,10 +145,14 @@ class Mamba2Model(nn.Module):
 
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
 
-        mamba2_metadata = prepare_mamba2_metadata(
-            chunk_size=self.config.chunk_size,
-            attn_metadata=attn_metadata,
-        )
+        if not envs.VLLM_USE_V1:
+            mamba2_metadata = prepare_mamba2_metadata(
+                chunk_size=self.config.chunk_size,
+                attn_metadata=attn_metadata,
+            )
+        else:
+            # v1 get mamba2_metadata from forward_context
+            mamba2_metadata = None
 
         for i in range(len(self.layers)):
             layer = self.layers[i]
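Side note on the branch above: with the V1 engine the model no longer builds its Mamba2 chunk metadata on every forward pass; the engine prepares it ahead of time and layers read it from the ambient forward context, which is why mamba2_metadata stays None on that path. The following is a self-contained sketch of that general pattern with made-up names (ForwardContextSketch, forward_context_sketch are not the vLLM API), assuming a simple dict as the metadata payload.

# Sketch of the forward-context pattern; names are illustrative, not vLLM API.
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, Optional


@dataclass
class ForwardContextSketch:
    mamba2_metadata: Optional[dict] = None


_CTX = ForwardContextSketch()


@contextmanager
def forward_context_sketch(metadata: dict) -> Iterator[None]:
    # The engine installs per-step metadata once...
    _CTX.mamba2_metadata = metadata
    try:
        yield
    finally:
        _CTX.mamba2_metadata = None


def layer_forward_sketch() -> dict:
    # ...and each layer pulls it from the context instead of a call argument.
    assert _CTX.mamba2_metadata is not None
    return _CTX.mamba2_metadata


with forward_context_sketch({"chunk_size": 256}):
    print(layer_forward_sketch())  # {'chunk_size': 256}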
@@ -155,7 +162,7 @@ class Mamba2Model(nn.Module):
                 hidden_states=hidden_states,
                 residual=residual,
                 mamba_cache_params=mamba_cache_params.at_layer_idx(
-                    i - self.start_layer),
+                    i - self.start_layer) if mamba_cache_params else None,
                 mamba2_metadata=mamba2_metadata)
 
         if not get_pp_group().is_last_rank:
@@ -190,8 +197,7 @@ class Mamba2Model(nn.Module):
         return loaded_params
 
 
-class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
-                        SupportsV0Only):
+class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
@@ -242,14 +248,20 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
-        if self.mamba_cache is None:
-            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
-                self.vllm_config.parallel_config, LayerBlockType.mamba)
-            self.mamba_cache = MambaCacheManager(
-                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
-                *self._get_mamba_cache_shape())
+        if not envs.VLLM_USE_V1:
+            if self.mamba_cache is None:
+                num_mamba_layers = (
+                    self.model_config.get_num_layers_by_block_type(
+                        self.vllm_config.parallel_config,
+                        LayerBlockType.mamba))
+                self.mamba_cache = MambaCacheManager(
+                    self.vllm_config, self.lm_head.weight.dtype,
+                    num_mamba_layers, *self._get_mamba_cache_shape())
 
-        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+        else:
+            # NOTE: mamba_cache_params is not needed for v1
+            mamba_cache_params = None
 
         hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
                                       intermediate_tensors, inputs_embeds)
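As an end-to-end check of the path this commit enables, something like the sketch below should exercise Mamba2 on the V1 engine. It is a hedged usage sketch, not part of the commit: it assumes a vLLM build that contains this change, and the checkpoint name is only an example of a Mamba2-architecture model.

# Hedged usage sketch; assumes a vLLM build that includes this change.
import os

os.environ["VLLM_USE_V1"] = "1"  # opt in to the V1 engine before importing vLLM

from vllm import LLM, SamplingParams

# Example Mamba2-architecture checkpoint; substitute any supported Mamba2 model.
llm = LLM(model="mistralai/Mamba-Codestral-7B-v0.1")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=8))
print(outputs[0].outputs[0].text)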