[V1] [Hybrid] Refactor mamba state shape calculation; enable V1 via cli (#20840)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Thomas Parnell
2025-07-15 13:04:35 +02:00
committed by GitHub
parent c586b55667
commit 3534c39a20
14 changed files with 441 additions and 353 deletions

View File

@@ -18,7 +18,7 @@ from transformers import Zamba2Config
from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -30,8 +30,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -843,6 +843,39 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
"1.weight": "B.weight",
})
@classmethod
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
- conv_state_shape: Shape for convolutional state cache
- temporal_state_shape: Shape for state space model cache
"""
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
intermediate_size = hf_config.mamba_expand * hf_config.hidden_size
return get_mamba_state_shape(
intermediate_size=intermediate_size,
tp_world_size=parallel_config.tensor_parallel_size,
n_groups=hf_config.mamba_ngroups,
num_heads=hf_config.n_mamba_heads,
head_dim=hf_config.mamba_headdim,
state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
"""Initialize the Zamba2 model for causal language modeling.
@@ -925,9 +958,13 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = self.config.num_hidden_layers
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype,
num_mamba_layers, *self._get_mamba_cache_shape())
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
self.mamba_cache = MambaCacheManager(self.vllm_config,
self.lm_head.weight.dtype,
num_mamba_layers,
*mamba_state_shape)
# Get cache parameters for current run
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
@@ -968,49 +1005,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
"""
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> tuple[tuple[int, int], tuple[int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Returns:
Tuple containing:
- conv_state_shape: Shape for convolutional state cache
- temporal_state_shape: Shape for state space model cache
"""
world_size = get_tensor_model_parallel_world_size()
intermediate_size = self.config.mamba_expand * self.config.hidden_size
# Extend groups if needed to ensure all groups needed by a head
# are sharded together
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
n_groups = (self.config.mamba_ngroups + extra_groups_for_head_shards(
self.config.mamba_ngroups, world_size))
# Calculate conv state shape (includes groups)
# - heads and n_groups are TP-ed
conv_dim = (intermediate_size +
2 * n_groups * self.config.mamba_d_state)
conv_state_shape = (
divide(conv_dim, world_size),
self.config.mamba_d_conv - 1,
)
# Calculate temporal state shape (per-head states)
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, d_head, d_state) = (128, 64, 128)
temporal_state_shape = (
divide(divide(intermediate_size, self.config.mamba_headdim),
world_size),
self.config.mamba_headdim,
self.config.mamba_d_state,
)
return conv_state_shape, temporal_state_shape
def compute_logits(
self,
hidden_states: torch.Tensor,