[V1] [Hybrid] Refactor mamba state shape calculation; enable V1 via cli (#20840)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
@@ -20,6 +20,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
                                                               update_metadata)
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    extra_groups_for_head_shards, get_mamba_state_shape)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
@@ -146,18 +148,6 @@ class Mixer2RMSNormGated(CustomOp):
         return out
 
 
-def extra_groups_for_head_shards(ngroups: int, tp_size: int):
-    """Compute the increase in group numbers to account for
-    replication in order to accompany the head shards."""
-
-    # in the case ngoups % tp_size == 0, this will be zero
-    if ngroups % tp_size == 0:
-        return 0
-
-    # for n_groups == 1, this is exactly tp_size - n_groups
-    return tp_size - ngroups
-
-
 def mamba_v2_sharded_weight_loader(
     shard_spec: list[tuple[int, int, float]],
     tp_size: int,
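Note: the helper removed above is relocated rather than deleted; the first hunk imports extra_groups_for_head_shards from the new mamba_utils module. A minimal sketch of the relocated function, with the body reproduced from the removed lines (the module placement and the return annotation are the only assumptions):

    # vllm/model_executor/layers/mamba/mamba_utils.py (sketch)
    def extra_groups_for_head_shards(ngroups: int, tp_size: int) -> int:
        """Compute the increase in group numbers to account for
        replication in order to accompany the head shards."""
        # when ngroups % tp_size == 0, no replication is needed
        if ngroups % tp_size == 0:
            return 0
        # for n_groups == 1, this is exactly tp_size - n_groups
        return tp_size - ngroups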
@@ -707,30 +697,12 @@ class MambaMixer2(MambaBase, CustomOp):
         return out
 
     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
-        world_size = get_tensor_model_parallel_world_size()
-
-        conv_state_shape, temporal_state_shape = None, None
-
-        # if n_groups is not divisible by world_size, need to extend the shards
-        # to ensure all groups needed by a head is sharded along with it
-        n_groups = (self.n_groups +
-                    extra_groups_for_head_shards(self.n_groups, world_size))
-
-        # - heads and n_groups are TP-ed
-        conv_dim = (self.intermediate_size +
-                    2 * n_groups * self.ssm_state_size)
-        # contiguous along 'dim' axis
-        conv_state_shape = (
-            self.conv_kernel_size - 1,
-            divide(conv_dim, world_size),
-        )
-
-        # These are not TP-ed as they depend on A, dt_bias, D
-        # - they are typically small
-        # e.g., (h_heads, d_head, d_state) = (128, 64, 128)
-        temporal_state_shape = (
-            divide(self.num_heads, world_size),
-            self.head_dim,
-            self.ssm_state_size,
-        )
-        return conv_state_shape, temporal_state_shape
+        return get_mamba_state_shape(
+            intermediate_size=self.intermediate_size,
+            tp_world_size=get_tensor_model_parallel_world_size(),
+            n_groups=self.n_groups,
+            num_heads=self.num_heads,
+            head_dim=self.head_dim,
+            state_size=self.ssm_state_size,
+            conv_kernel=self.conv_kernel_size,
+        )
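For context, a sketch of how get_mamba_state_shape could compute the same shapes behind the keyword interface used by the new call site. The logic is reassembled from the removed method body; the exact signature, the divide import path, and the module layout are assumptions, not the verified implementation:

    # vllm/model_executor/layers/mamba/mamba_utils.py (sketch, assumptions noted)
    from vllm.distributed import divide  # assumed import path for the TP divide helper


    def get_mamba_state_shape(
        intermediate_size: int,
        tp_world_size: int,
        n_groups: int,
        num_heads: int,
        head_dim: int,
        state_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        # if n_groups is not divisible by the TP world size, extend the shards
        # so that every group needed by a head is sharded along with it;
        # extra_groups_for_head_shards is the helper relocated in the hunk above
        n_groups += extra_groups_for_head_shards(n_groups, tp_world_size)

        # heads and groups are TP-ed; the conv state is contiguous along 'dim'
        conv_dim = intermediate_size + 2 * n_groups * state_size
        conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))

        # the temporal state follows A, dt_bias, D and is typically small,
        # e.g. (n_heads, d_head, d_state) = (128, 64, 128)
        temporal_state_shape = (
            divide(num_heads, tp_world_size),
            head_dim,
            state_size,
        )
        return conv_state_shape, temporal_state_shape

As a worked example under these assumptions: with intermediate_size=8192, n_groups=1, num_heads=128, head_dim=64, state_size=128, conv_kernel=4 and tp_world_size=2, one extra group pads n_groups to 2, conv_dim = 8192 + 2 * 2 * 128 = 8704, and the call returns conv_state_shape = (3, 4352) and temporal_state_shape = (64, 64, 128).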