[V1] [Hybrid] Refactor mamba state shape calculation; enable V1 via cli (#20840)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
@@ -18,7 +18,7 @@ from transformers import Zamba2Config
|
||||
from vllm import envs
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.activation import GeluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@@ -30,8 +30,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||
Mamba2Metadata, prepare_mamba2_metadata)
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
|
||||
MambaMixer2, extra_groups_for_head_shards)
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@@ -843,6 +843,39 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
"1.weight": "B.weight",
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_mamba_state_shape_from_config(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
use_v1: bool = True,
|
||||
) -> tuple[tuple[int, int], tuple[int, int, int]]:
|
||||
"""Calculate shapes for Mamba's convolutional and state caches.
|
||||
|
||||
Args:
|
||||
vllm_config: vLLM config
|
||||
use_v1: Get shapes for V1 (or V0)
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- conv_state_shape: Shape for convolutional state cache
|
||||
- temporal_state_shape: Shape for state space model cache
|
||||
"""
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
intermediate_size = hf_config.mamba_expand * hf_config.hidden_size
|
||||
|
||||
return get_mamba_state_shape(
|
||||
intermediate_size=intermediate_size,
|
||||
tp_world_size=parallel_config.tensor_parallel_size,
|
||||
n_groups=hf_config.mamba_ngroups,
|
||||
num_heads=hf_config.n_mamba_heads,
|
||||
head_dim=hf_config.mamba_headdim,
|
||||
state_size=hf_config.mamba_d_state,
|
||||
conv_kernel=hf_config.mamba_d_conv,
|
||||
use_v1=use_v1,
|
||||
)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
"""Initialize the Zamba2 model for causal language modeling.
|
||||
|
||||
@@ -925,9 +958,13 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
if not envs.VLLM_USE_V1:
|
||||
if self.mamba_cache is None:
|
||||
num_mamba_layers = self.config.num_hidden_layers
|
||||
self.mamba_cache = MambaCacheManager(
|
||||
self.vllm_config, self.lm_head.weight.dtype,
|
||||
num_mamba_layers, *self._get_mamba_cache_shape())
|
||||
mamba_state_shape = \
|
||||
self.get_mamba_state_shape_from_config(
|
||||
self.vllm_config, use_v1=False)
|
||||
self.mamba_cache = MambaCacheManager(self.vllm_config,
|
||||
self.lm_head.weight.dtype,
|
||||
num_mamba_layers,
|
||||
*mamba_state_shape)
|
||||
|
||||
# Get cache parameters for current run
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
@@ -968,49 +1005,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
"""
|
||||
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||
|
||||
def _get_mamba_cache_shape(
|
||||
self) -> tuple[tuple[int, int], tuple[int, int]]:
|
||||
"""Calculate shapes for Mamba's convolutional and state caches.
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- conv_state_shape: Shape for convolutional state cache
|
||||
- temporal_state_shape: Shape for state space model cache
|
||||
"""
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
intermediate_size = self.config.mamba_expand * self.config.hidden_size
|
||||
|
||||
# Extend groups if needed to ensure all groups needed by a head
|
||||
# are sharded together
|
||||
|
||||
# if n_groups is not divisible by world_size, need to extend the shards
|
||||
# to ensure all groups needed by a head is sharded along with it
|
||||
n_groups = (self.config.mamba_ngroups + extra_groups_for_head_shards(
|
||||
self.config.mamba_ngroups, world_size))
|
||||
|
||||
# Calculate conv state shape (includes groups)
|
||||
# - heads and n_groups are TP-ed
|
||||
conv_dim = (intermediate_size +
|
||||
2 * n_groups * self.config.mamba_d_state)
|
||||
conv_state_shape = (
|
||||
divide(conv_dim, world_size),
|
||||
self.config.mamba_d_conv - 1,
|
||||
)
|
||||
|
||||
# Calculate temporal state shape (per-head states)
|
||||
# These are not TP-ed as they depend on A, dt_bias, D
|
||||
# - they are typically small
|
||||
# e.g., (h_heads, d_head, d_state) = (128, 64, 128)
|
||||
temporal_state_shape = (
|
||||
divide(divide(intermediate_size, self.config.mamba_headdim),
|
||||
world_size),
|
||||
self.config.mamba_headdim,
|
||||
self.config.mamba_d_state,
|
||||
)
|
||||
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user