[v1] - Mamba1 Attention Metadata (#21249)

Signed-off-by: asafg <asafg@ai21.com>
Co-authored-by: asafg <asafg@ai21.com>
Author: Asaf Joseph Gardin
Date:   2025-08-07 03:03:42 +03:00 (committed by GitHub)
Parent: 31f09c615f
Commit: 46a13949d5

19 changed files with 367 additions and 161 deletions

vllm/model_executor/models/jamba.py

@@ -8,6 +8,7 @@ import torch
 from torch import nn
 from transformers import JambaConfig

+from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -19,6 +20,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateShapeCalculator)
 from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
                                                PoolingType)
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -32,8 +35,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType

-from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
-                         SupportsV0Only)
+from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -112,7 +114,8 @@ class JambaMambaDecoderLayer(nn.Module):
             use_rms_norm=True,
             rms_norm_eps=config.rms_norm_eps,
             activation=config.hidden_act,
-            is_lora_enabled = self.is_lora_enabled
+            is_lora_enabled = self.is_lora_enabled,
+            prefix=f"{prefix}.mixer",
         )

         num_experts = config.layers_num_experts[layer_idx]
@@ -344,7 +347,8 @@ class JambaModel(nn.Module):
             layer_mamba_cache_params = None
             if isinstance(layer, JambaAttentionDecoderLayer):
                 kv_cache_index += 1
-            if isinstance(layer, JambaMambaDecoderLayer):
+            if isinstance(layer,
+                          JambaMambaDecoderLayer) and mamba_cache_params:
                 current_state_layer = mamba_cache_index
                 layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
                     current_state_layer)
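
Note on the guard above: under v1 the model's forward pass now supplies mamba_cache_params=None (see the hunk at old line 509 below), so the added `and mamba_cache_params` check short-circuits and per-layer cache params are never dereferenced. A self-contained sketch of that behavior, using simplified stand-in classes rather than the real vLLM types:

    # Sketch only: stand-ins for JambaMambaDecoderLayer / MambaCacheParams.
    class MambaLayer:
        pass

    class CacheParams:
        def at_layer_idx(self, idx):
            return f"params[{idx}]"

    def pick_layer_params(layer, mamba_cache_params, mamba_cache_index):
        # v1 passes mamba_cache_params=None, so the truthiness check
        # short-circuits before at_layer_idx() is ever called.
        if isinstance(layer, MambaLayer) and mamba_cache_params:
            return mamba_cache_params.at_layer_idx(mamba_cache_index)
        return None

    assert pick_layer_params(MambaLayer(), None, 0) is None  # v1 path
    assert pick_layer_params(MambaLayer(), CacheParams(), 0) == "params[0]"  # v0 path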
@@ -442,7 +446,7 @@ class JambaModel(nn.Module):


 class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
-                       IsHybrid, SupportsV0Only):
+                       IsHybrid):
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
         ".self_attn.": ".",
         ".A_log": ".A"
@@ -509,14 +513,19 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
-        if self.mamba_cache is None:
-            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
-                self.vllm_config.parallel_config, LayerBlockType.mamba)
-            self.mamba_cache = MambaCacheManager(
-                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
-                *self._get_mamba_cache_shape())
-
-        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+        # NOTE: mamba_cache_params is not needed for v1
+        mamba_cache_params = None
+        if not envs.VLLM_USE_V1:
+            if self.mamba_cache is None:
+                num_layers = self.model_config.get_num_layers_by_block_type(
+                    self.vllm_config.parallel_config, LayerBlockType.mamba)
+                state_shape = self.get_mamba_state_shape_from_config(
+                    self.vllm_config)
+                self.mamba_cache = MambaCacheManager(self.vllm_config,
+                                                     self.lm_head.weight.dtype,
+                                                     num_layers, *state_shape)
+
+            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

         hidden_states = self.model(input_ids, positions, mamba_cache_params,
                                    intermediate_tensors, inputs_embeds)
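
Note: the hunk above is the core of the change. The legacy MambaCacheManager becomes a lazily created, v0-only object, while on the v1 path mamba_cache_params stays None and (per the commit title) Mamba1 state is driven through attention metadata instead. A condensed, self-contained sketch of the gating pattern, with an environment-variable stand-in for envs.VLLM_USE_V1 (the real flag and its default live in vllm.envs):

    import os

    # Stand-in for envs.VLLM_USE_V1; the default here is illustrative only.
    VLLM_USE_V1 = os.environ.get("VLLM_USE_V1", "0") == "1"

    class Model:
        def __init__(self):
            self.mamba_cache = None  # built lazily, and only on the v0 path

        def forward(self, **kwargs):
            # NOTE: mamba_cache_params is not needed for v1
            mamba_cache_params = None
            if not VLLM_USE_V1:
                if self.mamba_cache is None:
                    self.mamba_cache = object()  # stand-in for MambaCacheManager(...)
                mamba_cache_params = ("conv_state", "ssm_state")  # stand-in tensors
            return mamba_cache_params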
@@ -529,19 +538,22 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

-    def _get_mamba_cache_shape(
-            self) -> tuple[tuple[int, int], tuple[int, int]]:
-        world_size = get_tensor_model_parallel_world_size()
-        hidden_size = self.config.hidden_size
-        conv_state_shape = (
-            self.config.mamba_expand * hidden_size // world_size,
-            self.config.mamba_d_conv - 1,
-        )
-        temporal_state_shape = (
-            self.config.mamba_expand * hidden_size // world_size,
-            self.config.mamba_d_state,
-        )
-        return conv_state_shape, temporal_state_shape
+    @classmethod
+    def get_mamba_state_shape_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+    ) -> tuple[tuple[int, int], tuple[int, int]]:
+        parallel_config = vllm_config.parallel_config
+        hf_config = vllm_config.model_config.hf_config
+        hidden_size = hf_config.hidden_size
+
+        return MambaStateShapeCalculator.mamba1_state_shape(
+            tp_world_size=parallel_config.tensor_parallel_size,
+            intermediate_size=hf_config.mamba_expand * hidden_size,
+            state_size=hf_config.mamba_d_state,
+            conv_kernel=hf_config.mamba_d_conv,
+            use_v1=envs.VLLM_USE_V1,
+        )

     def compute_logits(
         self,
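
Note: the shapes the new classmethod must produce can be read directly off the deleted _get_mamba_cache_shape above: both the conv state and the temporal (SSM) state are sharded across tensor-parallel ranks along the intermediate dimension. Below is a hedged reconstruction of what MambaStateShapeCalculator.mamba1_state_shape presumably computes; the real helper lives in vllm/model_executor/layers/mamba/mamba_utils.py, and the effect of the use_v1 flag is not visible in this diff, so it is left as a documented pass-through:

    def mamba1_state_shape(tp_world_size: int, intermediate_size: int,
                           state_size: int, conv_kernel: int,
                           use_v1: bool = False):
        """Sketch reconstructed from the removed v0 code, not the real helper."""
        # Conv state keeps (conv_kernel - 1) past activations per sharded channel.
        conv_state_shape = (intermediate_size // tp_world_size, conv_kernel - 1)
        # Temporal (SSM) state keeps state_size columns per sharded channel.
        temporal_state_shape = (intermediate_size // tp_world_size, state_size)
        # NOTE (assumption): use_v1 may adjust the layout for the v1 cache;
        # the shapes above match only the removed v0 implementation.
        return conv_state_shape, temporal_state_shape

    # Illustrative Jamba-like values: mamba_expand=2, hidden_size=4096,
    # mamba_d_state=16, mamba_d_conv=4, tensor_parallel_size=2.
    print(mamba1_state_shape(2, 2 * 4096, 16, 4))  # ((4096, 3), (4096, 16))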