[V1] [Hybrid] Refactor mamba state shape calculation; enable V1 via cli (#20840)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Thomas Parnell
2025-07-15 13:04:35 +02:00
committed by GitHub
parent c586b55667
commit 3534c39a20
14 changed files with 441 additions and 353 deletions


@@ -20,6 +20,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
                                                               update_metadata)
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    extra_groups_for_head_shards, get_mamba_state_shape)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
@@ -146,18 +148,6 @@ class Mixer2RMSNormGated(CustomOp):
         return out
 
 
-def extra_groups_for_head_shards(ngroups: int, tp_size: int):
-    """Compute the increase in group numbers to account for
-    replication in order to accompany the head shards."""
-
-    # in the case ngoups % tp_size == 0, this will be zero
-    if ngroups % tp_size == 0:
-        return 0
-
-    # for n_groups == 1, this is exactly tp_size - n_groups
-    return tp_size - ngroups
-
-
 def mamba_v2_sharded_weight_loader(
     shard_spec: list[tuple[int, int, float]],
     tp_size: int,
@@ -707,30 +697,12 @@ class MambaMixer2(MambaBase, CustomOp):
         return out
 
     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
-        world_size = get_tensor_model_parallel_world_size()
-
-        conv_state_shape, temporal_state_shape = None, None
-        # if n_groups is not divisible by world_size, need to extend the shards
-        # to ensure all groups needed by a head is sharded along with it
-        n_groups = (self.n_groups +
-                    extra_groups_for_head_shards(self.n_groups, world_size))
-
-        # - heads and n_groups are TP-ed
-        conv_dim = (self.intermediate_size +
-                    2 * n_groups * self.ssm_state_size)
-
-        # contiguous along 'dim' axis
-        conv_state_shape = (
-            self.conv_kernel_size - 1,
-            divide(conv_dim, world_size),
-        )
-
-        # These are not TP-ed as they depend on A, dt_bias, D
-        # - they are typically small
-        #   e.g., (h_heads, d_head, d_state) = (128, 64, 128)
-        temporal_state_shape = (
-            divide(self.num_heads, world_size),
-            self.head_dim,
-            self.ssm_state_size,
-        )
-        return conv_state_shape, temporal_state_shape
+        return get_mamba_state_shape(
+            intermediate_size=self.intermediate_size,
+            tp_world_size=get_tensor_model_parallel_world_size(),
+            n_groups=self.n_groups,
+            num_heads=self.num_heads,
+            head_dim=self.head_dim,
+            state_size=self.ssm_state_size,
+            conv_kernel=self.conv_kernel_size,
+        )
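The shape math removed above now lives behind get_mamba_state_shape in mamba_utils, whose implementation is outside this hunk. A rough sketch of what it presumably computes, pieced together from the deleted code and the keyword arguments at the new call site (the local divide stand-in, the inlined group-padding rule, and the example numbers are assumptions; the real helper may also handle V1-specific layout details):

def divide(numerator: int, denominator: int) -> int:
    # Stand-in for vLLM's exact-division utility used in the removed code.
    assert numerator % denominator == 0
    return numerator // denominator

def get_mamba_state_shape(
    *,
    intermediate_size: int,
    tp_world_size: int,
    n_groups: int,
    num_heads: int,
    head_dim: int,
    state_size: int,
    conv_kernel: int,
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    # Same padding rule as extra_groups_for_head_shards above: nothing if the
    # groups already divide tp_world_size, otherwise pad to one group per shard.
    if n_groups % tp_world_size != 0:
        n_groups += tp_world_size - n_groups

    # Conv state: sharded along the channel ('dim') axis across TP ranks.
    conv_dim = intermediate_size + 2 * n_groups * state_size
    conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))

    # Temporal (SSM) state: heads are sharded; the per-head state is small,
    # e.g. (num_heads, head_dim, state_size) = (128, 64, 128).
    temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
    return conv_state_shape, temporal_state_shape

# Illustrative numbers only: with TP=2 the single group is padded to 2,
# so conv_dim = 4096 + 2 * 2 * 128 = 4608 and each rank keeps half of it.
conv_shape, ssm_shape = get_mamba_state_shape(
    intermediate_size=4096, tp_world_size=2, n_groups=1,
    num_heads=128, head_dim=64, state_size=128, conv_kernel=4,
)
assert conv_shape == (3, 2304)
assert ssm_shape == (64, 64, 128)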