[v1] - Mamba1 Attention Metadata (#21249)

Signed-off-by: asafg <asafg@ai21.com>
Co-authored-by: asafg <asafg@ai21.com>
Author: Asaf Joseph Gardin
Date: 2025-08-07 03:03:42 +03:00
Committer: GitHub
Parent: 31f09c615f
Commit: 46a13949d5
19 changed files with 367 additions and 161 deletions


vllm/model_executor/layers/mamba/mamba_mixer2.py

@@ -21,7 +21,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
                                                               update_metadata)
 from vllm.model_executor.layers.mamba.mamba_utils import (
-    extra_groups_for_head_shards, get_mamba_state_shape)
+    MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
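
The import swap above folds the two free helpers from mamba_utils into a single MambaStateShapeCalculator class. Below is a minimal, self-contained sketch of the interface this diff relies on, inferred only from the call sites visible here; the method bodies, and any parameters beyond intermediate_size, tp_world_size, and n_groups, are assumptions rather than the actual mamba_utils implementation.

# A minimal, self-contained sketch of the consolidated helper class, inferred
# only from the call sites in this diff. Bodies and extra parameters are
# assumptions, not the actual mamba_utils implementation.
class MambaStateShapeCalculator:

    @classmethod
    def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int) -> int:
        """Extra groups needed so each TP rank holds only whole groups."""
        if ngroups % tp_size == 0:
            return 0  # already shards evenly, nothing to extend
        # Assumed padding rule: replicate groups up to one per rank.
        return tp_size - ngroups

    @classmethod
    def mamba2_state_shape(
        cls,
        intermediate_size: int,
        tp_world_size: int,
        n_groups: int,
        num_heads: int,    # assumed parameter
        head_dim: int,     # assumed parameter
        state_size: int,   # assumed parameter
        conv_kernel: int,  # assumed parameter
    ) -> tuple[tuple[int, ...], tuple[int, ...]]:
        # Mirrors the conv_dim formula visible in the second hunk below.
        conv_dim = intermediate_size + 2 * n_groups * state_size
        # Conv state keeps the last (conv_kernel - 1) inputs per channel.
        conv_state_shape = (conv_dim // tp_world_size, conv_kernel - 1)
        # SSM (temporal) state is per-head: (heads, head_dim, state_size).
        temporal_state_shape = (num_heads // tp_world_size, head_dim,
                                state_size)
        return conv_state_shape, temporal_state_shape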
@@ -278,8 +278,9 @@ class MambaMixer2(MambaBase, CustomOp):
         # - for TP we shard conv_dim by sharding on n_groups,
         # - but if n_groups cannot divide tp_size, we need to
         #   extend some extra groups
-        self.n_groups = n_groups + extra_groups_for_head_shards(
+        groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
             n_groups, self.tp_size)
+        self.n_groups = n_groups + groups
         self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size
         self.conv1d = ColumnParallelLinear(
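
To make the arithmetic above concrete, here is a worked example using the sketch after the first hunk (the padding rule there is assumed): with a single group and four TP ranks, three extra groups are replicated so every rank owns a whole group.

# Worked example of the group extension (padding rule assumed as sketched
# above): one group, four TP ranks -> three extra groups.
groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
    ngroups=1, tp_size=4)
assert groups == 3  # self.n_groups becomes 1 + 3 == 4
# conv_dim then grows accordingly, e.g. intermediate_size=4096 and
# ssm_state_size=128 give 4096 + 2 * 4 * 128 == 5120.
assert 4096 + 2 * (1 + groups) * 128 == 5120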
@@ -732,7 +733,7 @@ class MambaMixer2(MambaBase, CustomOp):
         output[:num_actual_tokens], _ = self.out_proj(hidden_states)

     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
-        return get_mamba_state_shape(
+        return MambaStateShapeCalculator.mamba2_state_shape(
             intermediate_size=self.intermediate_size,
             tp_world_size=get_tensor_model_parallel_world_size(),
             n_groups=self.n_groups,
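
The argument list above continues past the end of this excerpt, so the remaining keyword arguments are not shown. Assuming the sketched signature, a hypothetical call would resolve like this (all concrete values are illustrative, and every parameter past n_groups is an assumption):

# Hypothetical usage of the new call, matched to the sketch after the first
# hunk; names and values past n_groups are illustrative assumptions.
conv_state_shape, temporal_state_shape = (
    MambaStateShapeCalculator.mamba2_state_shape(
        intermediate_size=4096,
        tp_world_size=2,
        n_groups=8,
        num_heads=64,     # assumed
        head_dim=64,      # assumed
        state_size=128,   # assumed
        conv_kernel=4,    # assumed
    ))
assert conv_state_shape == (3072, 3)  # (4096 + 2*8*128) // 2, kernel - 1
assert temporal_state_shape == (32, 64, 128)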