[V1] [Hybrid] Refactor mamba state shape calculation; enable V1 via cli (#20840)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
@@ -12,7 +12,7 @@ from transformers import BambaConfig
|
||||
from vllm import envs
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
@@ -23,8 +23,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||
Mamba2Metadata, prepare_mamba2_metadata)
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
|
||||
MambaMixer2, extra_groups_for_head_shards)
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@@ -435,6 +435,38 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
}
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
|
||||
@classmethod
|
||||
def get_mamba_state_shape_from_config(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
use_v1: bool = True,
|
||||
) -> tuple[tuple[int, int], tuple[int, int, int]]:
|
||||
"""Calculate shapes for Mamba's convolutional and state caches.
|
||||
|
||||
Args:
|
||||
vllm_config: vLLM config
|
||||
use_v1: Get shapes for V1 (or V0)
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- conv_state_shape: Shape for convolutional state cache
|
||||
- temporal_state_shape: Shape for state space model cache
|
||||
"""
|
||||
parallel_config = vllm_config.parallel_config
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
intermediate_size = hf_config.mamba_expand * hf_config.hidden_size
|
||||
|
||||
return get_mamba_state_shape(
|
||||
intermediate_size=intermediate_size,
|
||||
tp_world_size=parallel_config.tensor_parallel_size,
|
||||
n_groups=hf_config.mamba_n_groups,
|
||||
num_heads=hf_config.mamba_n_heads,
|
||||
head_dim=hf_config.mamba_d_head,
|
||||
state_size=hf_config.mamba_d_state,
|
||||
conv_kernel=hf_config.mamba_d_conv,
|
||||
use_v1=use_v1,
|
||||
)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.vllm_config = vllm_config
|
||||
@@ -491,10 +523,13 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
self.vllm_config.parallel_config,
|
||||
LayerBlockType.mamba
|
||||
)
|
||||
|
||||
self.mamba_cache = MambaCacheManager(
|
||||
self.vllm_config, self.lm_head.weight.dtype,
|
||||
num_mamba_layers, *self._get_mamba_cache_shape())
|
||||
mamba_state_shape = \
|
||||
self.get_mamba_state_shape_from_config(
|
||||
self.vllm_config, use_v1=False)
|
||||
self.mamba_cache = MambaCacheManager(self.vllm_config,
|
||||
self.lm_head.weight.dtype,
|
||||
num_mamba_layers,
|
||||
*mamba_state_shape)
|
||||
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
|
||||
@@ -510,38 +545,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||
|
||||
def _get_mamba_cache_shape(
|
||||
self) -> tuple[tuple[int, int], tuple[int, int]]:
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
hidden_size = self.config.hidden_size
|
||||
|
||||
conv_state_shape, temporal_state_shape = None, None
|
||||
|
||||
intermediate_size = self.config.mamba_expand * hidden_size
|
||||
|
||||
# if n_groups is not divisible by world_size, need to extend the shards
|
||||
# to ensure all groups needed by a head is sharded along with it
|
||||
n_groups = (self.config.mamba_n_groups + extra_groups_for_head_shards(
|
||||
self.config.mamba_n_groups, world_size))
|
||||
|
||||
# - heads and n_groups are TP-ed
|
||||
conv_dim = (intermediate_size +
|
||||
2 * n_groups * self.config.mamba_d_state)
|
||||
conv_state_shape = (
|
||||
divide(conv_dim, world_size),
|
||||
self.config.mamba_d_conv - 1,
|
||||
)
|
||||
|
||||
# These are not TP-ed as they depend on A, dt_bias, D
|
||||
# - they are typically small
|
||||
# e.g., (h_heads, d_head, d_state) = (128, 64, 128)
|
||||
temporal_state_shape = (
|
||||
divide(self.config.mamba_n_heads, world_size),
|
||||
self.config.mamba_d_head,
|
||||
self.config.mamba_d_state,
|
||||
)
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user