[V1] [Hybrid] Refactor mamba state shape calculation; enable V1 via cli (#20840)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
@@ -11,15 +11,14 @@ from transformers import MambaConfig
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||
Mamba2Metadata, prepare_mamba2_metadata)
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
|
||||
MambaMixer2, extra_groups_for_head_shards)
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@@ -198,6 +197,38 @@ class Mamba2Model(nn.Module):
|
||||
|
||||
class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
|
||||
@classmethod
|
||||
def get_mamba_state_shape_from_config(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
use_v1: bool = True,
|
||||
) -> tuple[tuple[int, int], tuple[int, int, int]]:
|
||||
"""Calculate shapes for Mamba's convolutional and state caches.
|
||||
|
||||
Args:
|
||||
vllm_config: vLLM config
|
||||
use_v1: Get shapes for V1 (or V0)
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- conv_state_shape: Shape for convolutional state cache
|
||||
- temporal_state_shape: Shape for state space model cache
|
||||
"""
|
||||
parallel_config = vllm_config.parallel_config
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
intermediate_size = hf_config.expand * hf_config.hidden_size
|
||||
|
||||
return get_mamba_state_shape(
|
||||
intermediate_size=intermediate_size,
|
||||
tp_world_size=parallel_config.tensor_parallel_size,
|
||||
n_groups=hf_config.n_groups,
|
||||
num_heads=hf_config.num_heads,
|
||||
head_dim=hf_config.head_dim,
|
||||
state_size=hf_config.state_size,
|
||||
conv_kernel=hf_config.conv_kernel,
|
||||
use_v1=use_v1,
|
||||
)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
@@ -253,9 +284,13 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
self.model_config.get_num_layers_by_block_type(
|
||||
self.vllm_config.parallel_config,
|
||||
LayerBlockType.mamba))
|
||||
self.mamba_cache = MambaCacheManager(
|
||||
self.vllm_config, self.lm_head.weight.dtype,
|
||||
num_mamba_layers, *self._get_mamba_cache_shape())
|
||||
mamba_state_shape = \
|
||||
self.get_mamba_state_shape_from_config(
|
||||
self.vllm_config, use_v1=False)
|
||||
self.mamba_cache = MambaCacheManager(self.vllm_config,
|
||||
self.lm_head.weight.dtype,
|
||||
num_mamba_layers,
|
||||
*mamba_state_shape)
|
||||
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
else:
|
||||
@@ -274,39 +309,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||
|
||||
def _get_mamba_cache_shape(
|
||||
self) -> tuple[tuple[int, int], tuple[int, int]]:
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
conv_state_shape, temporal_state_shape = None, None
|
||||
|
||||
intermediate_size = getattr(
|
||||
self.config, "intermediate_size",
|
||||
self.config.expand * self.config.hidden_size)
|
||||
|
||||
# if n_groups is not divisible by world_size, need to extend the shards
|
||||
# to ensure all groups needed by a head is sharded along with it
|
||||
n_groups = (
|
||||
self.config.n_groups +
|
||||
extra_groups_for_head_shards(self.config.n_groups, world_size))
|
||||
|
||||
# - heads and n_groups are TP-ed
|
||||
conv_dim = (intermediate_size + 2 * n_groups * self.config.state_size)
|
||||
conv_state_shape = (
|
||||
divide(conv_dim, world_size),
|
||||
self.config.conv_kernel - 1,
|
||||
)
|
||||
|
||||
# These are not TP-ed as they depend on A, dt_bias, D
|
||||
# - they are typically small
|
||||
# e.g., (h_heads, d_head, d_state) = (128, 64, 128)
|
||||
temporal_state_shape = (
|
||||
divide(self.config.num_heads, world_size),
|
||||
self.config.head_dim,
|
||||
self.config.state_size,
|
||||
)
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
|
||||
Reference in New Issue
Block a user