[v1] - Mamba1 Attention Metadata (#21249)
Signed-off-by: asafg <asafg@ai21.com>
Co-authored-by: asafg <asafg@ai21.com>
committed by GitHub
parent 31f09c615f
commit 46a13949d5
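The new code path in this diff is selected at runtime by the VLLM_USE_V1 environment variable, read through vllm.envs. A minimal usage sketch, assuming an offline-inference setup; state-spaces/mamba-130m-hf is used purely as an example Mamba1 checkpoint:

import os

# Opt into the V1 engine; with this change Mamba1 models no longer need to
# fall back to the V0 path guarded below by `if not envs.VLLM_USE_V1`.
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="state-spaces/mamba-130m-hf")  # example Mamba1 model
params = SamplingParams(temperature=0.0, max_tokens=32)
out = llm.generate(["The capital of France is"], params)
print(out[0].outputs[0].text)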
@@ -8,20 +8,21 @@ import torch
 from torch import nn
 from transformers import MambaConfig
 
+from vllm import envs
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import (HasInnerState,
-                                                   IsAttentionFree, SupportsPP,
-                                                   SupportsV0Only)
+                                                   IsAttentionFree, SupportsPP)
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -41,7 +42,8 @@ class MambaDecoderLayer(nn.Module):
                  config: MambaConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 is_lora_enabled: Optional[bool] = False) -> None:
+                 is_lora_enabled: Optional[bool] = False,
+                 prefix: str = "") -> None:
         super().__init__()
         self.config = config
         self.is_falcon_mamba = config.model_type == "falcon_mamba"
@@ -58,7 +60,8 @@ class MambaDecoderLayer(nn.Module):
                                 rms_norm_has_weight=not self.is_falcon_mamba,
                                 rms_norm_eps=mixer_rms_eps,
                                 activation=config.hidden_act,
-                                is_lora_enabled=self.is_lora_enabled)
+                                is_lora_enabled=self.is_lora_enabled,
+                                prefix=f"{prefix}.mixer")
 
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
@@ -107,7 +110,8 @@ class MambaModel(nn.Module):
             lambda prefix: MambaDecoderLayer(config,
                                              cache_config=cache_config,
                                              quant_config=quant_config,
-                                             is_lora_enabled=is_lora_enabled),
+                                             is_lora_enabled=is_lora_enabled,
+                                             prefix=prefix),
             prefix=f"{prefix}.layers")
 
         self.norm_f = RMSNorm(config.hidden_size,
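The three hunks above thread a prefix argument from the model down into every mixer, following vLLM's convention that each submodule extends its parent's dotted name (make_layers is assumed to append the layer index). A minimal standalone sketch of that naming scheme, with toy classes standing in for MambaDecoderLayer and MambaMixer:

class Mixer:

    def __init__(self, prefix: str = "") -> None:
        # The dotted path uniquely identifies this mixer instance.
        self.prefix = prefix


class DecoderLayer:

    def __init__(self, prefix: str = "") -> None:
        # Mirrors MambaDecoderLayer above: the mixer is "<layer prefix>.mixer".
        self.mixer = Mixer(prefix=f"{prefix}.mixer")


# Mirrors make_layers(..., prefix=f"{prefix}.layers"), assumed to name each
# layer "<model prefix>.layers.<index>".
layers = [DecoderLayer(prefix=f"backbone.layers.{i}") for i in range(2)]
assert layers[1].mixer.prefix == "backbone.layers.1.mixer"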
@@ -123,7 +127,7 @@ class MambaModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        mamba_cache_params: MambaCacheParams,
+        mamba_cache_params: Optional[MambaCacheParams] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -140,12 +144,17 @@
 
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
+
+            layer_cache_params = None
+            if mamba_cache_params is not None:
+                layer_cache_params = mamba_cache_params.at_layer_idx(
+                    i - self.start_layer)
+
             hidden_states, residual = layer(
                 positions=positions,
                 hidden_states=hidden_states,
                 residual=residual,
-                mamba_cache_params=mamba_cache_params.at_layer_idx(
-                    i - self.start_layer))
+                mamba_cache_params=layer_cache_params)
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
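With mamba_cache_params now Optional, MambaModel.forward has two call paths: the V0 runner still hands in a MambaCacheParams bundle, while the V1 path passes None and the mixers are expected to find their state through the new attention metadata instead. A small sketch of the guard introduced above, using a stand-in dataclass rather than the real MambaCacheParams:

from dataclasses import dataclass
from typing import Optional


@dataclass
class CacheParams:  # stand-in for MambaCacheParams
    layer_idx: int = 0

    def at_layer_idx(self, i: int) -> "CacheParams":
        return CacheParams(layer_idx=i)


def run_layers(num_layers: int, cache_params: Optional[CacheParams]) -> None:
    for i in range(num_layers):
        # Same guard as in the hunk above: only slice per-layer params when a
        # V0 cache manager provided them; on the V1 path layers receive None.
        layer_cache_params = None
        if cache_params is not None:
            layer_cache_params = cache_params.at_layer_idx(i)
        print(i, layer_cache_params)


run_layers(2, CacheParams())  # V0-style call
run_layers(2, None)           # V1-style call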
@@ -176,8 +185,7 @@
         return loaded_params
 
 
-class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
-                       SupportsV0Only):
+class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
@@ -227,20 +235,40 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
-        if self.mamba_cache is None:
-            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
-                self.vllm_config.parallel_config, LayerBlockType.mamba)
-            self.mamba_cache = MambaCacheManager(
-                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
-                *self._get_mamba_cache_shape())
 
-        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+        mamba_cache_params = None
+        if not envs.VLLM_USE_V1:
+            if self.mamba_cache is None:
+                num_layers = self.model_config.get_num_layers_by_block_type(
+                    self.vllm_config.parallel_config, LayerBlockType.mamba)
+                state_shape = self.get_mamba_state_shape_from_config(
+                    self.vllm_config)
+                self.mamba_cache = MambaCacheManager(self.vllm_config,
+                                                     self.lm_head.weight.dtype,
+                                                     num_layers, *state_shape)
+
+            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
 
         hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
                                       intermediate_tensors, inputs_embeds)
 
         return hidden_states
 
+    @classmethod
+    def get_mamba_state_shape_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+    ) -> tuple[tuple[int, int], tuple[int, int]]:
+        parallel_config = vllm_config.parallel_config
+        hf_config = vllm_config.model_config.hf_config
+
+        return MambaStateShapeCalculator.mamba1_state_shape(
+            tp_world_size=parallel_config.tensor_parallel_size,
+            intermediate_size=hf_config.intermediate_size,
+            state_size=hf_config.state_size,
+            conv_kernel=hf_config.conv_kernel,
+            use_v1=envs.VLLM_USE_V1)
+
     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
         return self.mamba_cache.copy_inputs_before_cuda_graphs(
             input_buffers, **kwargs)
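For reference, the _get_mamba_cache_shape helper removed in the next hunk computed the per-tensor-parallel-rank state shapes by hand: a conv state of (intermediate_size // tp, conv_kernel - 1) and a temporal (SSM) state of (intermediate_size // tp, state_size). A standalone sketch of that arithmetic, as an illustration only and not the actual MambaStateShapeCalculator.mamba1_state_shape implementation (whose V1 behaviour may differ):

def mamba1_state_shape(tp_world_size: int, intermediate_size: int,
                       state_size: int, conv_kernel: int
                       ) -> tuple[tuple[int, int], tuple[int, int]]:
    # Per-layer Mamba1 state, sharded across tensor-parallel ranks, following
    # the arithmetic of the removed _get_mamba_cache_shape helper.
    conv_state_shape = (intermediate_size // tp_world_size, conv_kernel - 1)
    temporal_state_shape = (intermediate_size // tp_world_size, state_size)
    return conv_state_shape, temporal_state_shape


# mamba-130m-style config: intermediate_size=1536, state_size=16, conv_kernel=4
print(mamba1_state_shape(1, 1536, 16, 4))  # ((1536, 3), (1536, 16))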
@@ -248,19 +276,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
     def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
 
-    def _get_mamba_cache_shape(
-            self) -> tuple[tuple[int, int], tuple[int, int]]:
-        world_size = get_tensor_model_parallel_world_size()
-        conv_state_shape = (
-            self.config.intermediate_size // world_size,
-            self.config.conv_kernel - 1,
-        )
-        temporal_state_shape = (
-            self.config.intermediate_size // world_size,
-            self.config.state_size,
-        )
-        return conv_state_shape, temporal_state_shape
-
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
         logits = self.logits_processor(self.lm_head, hidden_states,