[v1] - Mamba1 Attention Metadata (#21249)

Signed-off-by: asafg <asafg@ai21.com>
Co-authored-by: asafg <asafg@ai21.com>
This commit is contained in:
Asaf Joseph Gardin
2025-08-07 03:03:42 +03:00
committed by GitHub
parent 31f09c615f
commit 46a13949d5
19 changed files with 367 additions and 161 deletions

View File

@@ -25,7 +25,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -457,7 +458,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
hf_config = vllm_config.model_config.hf_config
intermediate_size = hf_config.mamba_expand * hf_config.hidden_size
return get_mamba_state_shape(
return MambaStateShapeCalculator.mamba2_state_shape(
intermediate_size=intermediate_size,
tp_world_size=parallel_config.tensor_parallel_size,
n_groups=hf_config.mamba_n_groups,