[V1] [Hybrid] Refactor mamba state shape calculation; enable V1 via cli (#20840)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Thomas Parnell
2025-07-15 13:04:35 +02:00
committed by GitHub
parent c586b55667
commit 3534c39a20
14 changed files with 441 additions and 353 deletions

View File

@@ -11,15 +11,14 @@ from transformers import MambaConfig
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -198,6 +197,38 @@ class Mamba2Model(nn.Module):
class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
@classmethod
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
- conv_state_shape: Shape for convolutional state cache
- temporal_state_shape: Shape for state space model cache
"""
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
intermediate_size = hf_config.expand * hf_config.hidden_size
return get_mamba_state_shape(
intermediate_size=intermediate_size,
tp_world_size=parallel_config.tensor_parallel_size,
n_groups=hf_config.n_groups,
num_heads=hf_config.num_heads,
head_dim=hf_config.head_dim,
state_size=hf_config.state_size,
conv_kernel=hf_config.conv_kernel,
use_v1=use_v1,
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
@@ -253,9 +284,13 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba))
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype,
num_mamba_layers, *self._get_mamba_cache_shape())
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
self.mamba_cache = MambaCacheManager(self.vllm_config,
self.lm_head.weight.dtype,
num_mamba_layers,
*mamba_state_shape)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
else:
@@ -274,39 +309,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> tuple[tuple[int, int], tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
conv_state_shape, temporal_state_shape = None, None
intermediate_size = getattr(
self.config, "intermediate_size",
self.config.expand * self.config.hidden_size)
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
n_groups = (
self.config.n_groups +
extra_groups_for_head_shards(self.config.n_groups, world_size))
# - heads and n_groups are TP-ed
conv_dim = (intermediate_size + 2 * n_groups * self.config.state_size)
conv_state_shape = (
divide(conv_dim, world_size),
self.config.conv_kernel - 1,
)
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, d_head, d_state) = (128, 64, 128)
temporal_state_shape = (
divide(self.config.num_heads, world_size),
self.config.head_dim,
self.config.state_size,
)
return conv_state_shape, temporal_state_shape
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,