[v1] - Mamba1 Attention Metadata (#21249)
Signed-off-by: asafg <asafg@ai21.com> Co-authored-by: asafg <asafg@ai21.com>
This commit is contained in:
committed by
GitHub
parent
31f09c615f
commit
46a13949d5
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
|
||||
update_metadata)
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
extra_groups_for_head_shards, get_mamba_state_shape)
|
||||
MambaStateShapeCalculator)
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
|
||||
@@ -278,8 +278,9 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
# - for TP we shard conv_dim by sharding on n_groups,
|
||||
# - but if n_groups cannot divide tp_size, we need to
|
||||
# extend some extra groups
|
||||
self.n_groups = n_groups + extra_groups_for_head_shards(
|
||||
groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
|
||||
n_groups, self.tp_size)
|
||||
self.n_groups = n_groups + groups
|
||||
|
||||
self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size
|
||||
self.conv1d = ColumnParallelLinear(
|
||||
@@ -732,7 +733,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
|
||||
|
||||
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
return get_mamba_state_shape(
|
||||
return MambaStateShapeCalculator.mamba2_state_shape(
|
||||
intermediate_size=self.intermediate_size,
|
||||
tp_world_size=get_tensor_model_parallel_world_size(),
|
||||
n_groups=self.n_groups,
|
||||
|
||||
Reference in New Issue
Block a user