[v1] Support mamba2 (#19327)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -6,7 +6,9 @@ from typing import Optional, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
@@ -27,6 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
LoaderFunction, composed_weight_loader, sharded_weight_loader)
|
||||
from vllm.model_executor.models.mamba_cache import MambaCacheParams
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata
|
||||
|
||||
# Added by the IBM Team, 2024
|
||||
|
||||
@@ -227,20 +230,22 @@ class MambaMixer2(CustomOp):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
ssm_state_size: int,
|
||||
conv_kernel_size: int,
|
||||
intermediate_size: int,
|
||||
use_conv_bias: bool,
|
||||
use_bias: bool,
|
||||
n_groups: int = 1,
|
||||
num_heads: int = 128,
|
||||
head_dim: int = 64,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
activation: str = "silu",
|
||||
use_rms_norm: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
self,
|
||||
hidden_size: int,
|
||||
ssm_state_size: int,
|
||||
conv_kernel_size: int,
|
||||
intermediate_size: int,
|
||||
use_conv_bias: bool,
|
||||
use_bias: bool,
|
||||
n_groups: int = 1,
|
||||
num_heads: int = 128,
|
||||
head_dim: int = 64,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
activation: str = "silu",
|
||||
use_rms_norm: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
chunk_size: int = -1, # the chunk size used by v1
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -273,6 +278,7 @@ class MambaMixer2(CustomOp):
|
||||
), "Tensor parallel currently not supported for quantized models."
|
||||
|
||||
self.ssm_state_size = ssm_state_size
|
||||
self.conv_kernel_size = conv_kernel_size
|
||||
self.activation = activation
|
||||
|
||||
self.intermediate_size = intermediate_size
|
||||
@@ -411,6 +417,22 @@ class MambaMixer2(CustomOp):
|
||||
self.use_rms_norm,
|
||||
eps=rms_norm_eps)
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
# The outer list is for v0 PP virtual engine. Though this code path
|
||||
# only runs for v1, we have to do this to unify with the interface
|
||||
# of Attention + v0 PP.
|
||||
# The inner tuple is (conv_state, ssm_state)
|
||||
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
|
||||
assert chunk_size != -1, "chunk_size must be set for v1"
|
||||
|
||||
# NOTE: chunk_size may be -1 for models without v1 support
|
||||
self.chunk_size = chunk_size
|
||||
self.prefix = prefix
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -426,17 +448,37 @@ class MambaMixer2(CustomOp):
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
):
|
||||
forward_context = get_forward_context()
|
||||
# mamba2_metadata contains metadata necessary for the mamba2 triton
|
||||
# kernels to operate in continuous batching and in chunked prefill
|
||||
# modes; they are computed at top-level model forward since they
|
||||
# stay the same and reused for all mamba layers in the same iteration
|
||||
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
|
||||
|
||||
num_prefills = attn_metadata.num_prefills # request count
|
||||
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
|
||||
has_prefill = num_prefills > 0
|
||||
has_decode = num_decodes > 0
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
if envs.VLLM_USE_V1:
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
conv_state = self_kv_cache[0]
|
||||
ssm_state = self_kv_cache[1]
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
has_initial_states_p = attn_metadata.has_initial_states
|
||||
prep_initial_states = attn_metadata.prep_initial_states
|
||||
chunk_size = attn_metadata.chunk_size
|
||||
seq_idx_p = attn_metadata.seq_idx
|
||||
chunk_indices_p = attn_metadata.chunk_indices
|
||||
chunk_offsets_p = attn_metadata.chunk_offsets
|
||||
else:
|
||||
conv_state = mamba_cache_params.conv_state
|
||||
ssm_state = mamba_cache_params.ssm_state
|
||||
state_indices_tensor = mamba_cache_params.state_indices_tensor
|
||||
has_initial_states_p = mamba2_metadata.has_initial_states
|
||||
prep_initial_states = mamba2_metadata.prep_initial_states
|
||||
chunk_size = mamba2_metadata.chunk_size
|
||||
seq_idx_p = mamba2_metadata.seq_idx
|
||||
chunk_indices_p = mamba2_metadata.chunk_indices
|
||||
chunk_offsets_p = mamba2_metadata.chunk_offsets
|
||||
|
||||
groups_time_state_size = self.n_groups * self.ssm_state_size
|
||||
|
||||
@@ -459,27 +501,6 @@ class MambaMixer2(CustomOp):
|
||||
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
|
||||
self.conv1d.weight.size(2))
|
||||
|
||||
# Separate prefill and decode by splitting varlen input
|
||||
# Split along token dimension
|
||||
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
|
||||
hidden_states_B_C,
|
||||
[num_prefill_tokens, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
dt_p, dt_d = torch.split(
|
||||
dt,
|
||||
[num_prefill_tokens, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
# Split along batch dimension
|
||||
state_indices_tensor_p, state_indices_tensor_d = torch.split(
|
||||
mamba_cache_params.state_indices_tensor,
|
||||
[num_prefills, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
|
||||
if has_prefill else None)
|
||||
|
||||
# - get hidden_states, B and C after depthwise convolution.
|
||||
split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
|
||||
hidden_states_B_C,
|
||||
@@ -491,20 +512,80 @@ class MambaMixer2(CustomOp):
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
if envs.VLLM_USE_V1 and attn_metadata is None:
|
||||
# V1 profile run
|
||||
hidden_states_B_C = (hidden_states_B_C.transpose(
|
||||
0, 1).clone().transpose(0, 1)).contiguous()
|
||||
hidden_states, _B, _C = split_hidden_states_B_C_fn(
|
||||
hidden_states_B_C)
|
||||
hidden_states = self.norm(hidden_states, gate)
|
||||
out, _ = self.out_proj(hidden_states)
|
||||
return out
|
||||
|
||||
num_prefills = attn_metadata.num_prefills # request count
|
||||
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
|
||||
has_prefill = num_prefills > 0
|
||||
has_decode = num_decodes > 0
|
||||
|
||||
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
|
||||
# Separate prefill and decode by splitting varlen input
|
||||
# Split along token dimension
|
||||
if envs.VLLM_USE_V1:
|
||||
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
|
||||
hidden_states_B_C,
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
dt_d, dt_p = torch.split(
|
||||
dt,
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
# Split along batch dimension
|
||||
state_indices_tensor_d, state_indices_tensor_p = torch.split(
|
||||
state_indices_tensor,
|
||||
[num_decodes, num_prefills],
|
||||
dim=0,
|
||||
)
|
||||
query_start_loc_p = (
|
||||
attn_metadata.query_start_loc[-num_prefills - 1:] -
|
||||
num_decodes if has_prefill else None)
|
||||
else:
|
||||
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
|
||||
hidden_states_B_C,
|
||||
[num_prefill_tokens, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
dt_p, dt_d = torch.split(
|
||||
dt,
|
||||
[num_prefill_tokens, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
# Split along batch dimension
|
||||
state_indices_tensor_p, state_indices_tensor_d = torch.split(
|
||||
state_indices_tensor,
|
||||
[num_prefills, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
|
||||
1]
|
||||
if has_prefill else None)
|
||||
|
||||
ssd_output_list = []
|
||||
|
||||
# Process prefill requests
|
||||
if has_prefill:
|
||||
# 2. Convolution sequence transformation
|
||||
# - "cache_indices" updates the conv_state cache in positions
|
||||
# pointed to by "mamba_cache_params.state_indices_tensor"
|
||||
# pointed to by "state_indices_tensor"
|
||||
hidden_states_B_C_p = causal_conv1d_fn(
|
||||
hidden_states_B_C_p.transpose(0, 1),
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
conv_states=mamba_cache_params.conv_state,
|
||||
has_initial_state=mamba2_metadata.has_initial_states,
|
||||
conv_states=conv_state,
|
||||
has_initial_state=has_initial_states_p,
|
||||
cache_indices=state_indices_tensor_p,
|
||||
query_start_loc=query_start_loc_p).transpose(
|
||||
0, 1)[:num_prefill_tokens]
|
||||
@@ -516,12 +597,11 @@ class MambaMixer2(CustomOp):
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
initial_states = None
|
||||
if (mamba2_metadata.has_initial_states is not None
|
||||
and mamba2_metadata.prep_initial_states):
|
||||
if (has_initial_states_p is not None and prep_initial_states):
|
||||
# making a copy of the states
|
||||
initial_states = torch.where(
|
||||
mamba2_metadata.has_initial_states[:, None, None, None],
|
||||
mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
|
||||
has_initial_states_p[:, None, None, None],
|
||||
ssm_state[state_indices_tensor_p], 0)
|
||||
|
||||
scan_output, varlen_state = mamba_chunk_scan_combined(
|
||||
hidden_states_p.view(1, num_prefill_tokens,
|
||||
@@ -533,14 +613,14 @@ class MambaMixer2(CustomOp):
|
||||
-1),
|
||||
C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
|
||||
-1),
|
||||
chunk_size=mamba2_metadata.chunk_size,
|
||||
chunk_size=chunk_size,
|
||||
D=self.D,
|
||||
z=None,
|
||||
dt_bias=self.dt_bias,
|
||||
seq_idx=mamba2_metadata.seq_idx,
|
||||
chunk_indices=mamba2_metadata.chunk_indices,
|
||||
chunk_offsets=mamba2_metadata.chunk_offsets,
|
||||
cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1],
|
||||
seq_idx=seq_idx_p,
|
||||
chunk_indices=chunk_indices_p,
|
||||
chunk_offsets=chunk_offsets_p,
|
||||
cu_seqlens=query_start_loc_p,
|
||||
initial_states=initial_states,
|
||||
return_varlen_states=True,
|
||||
return_final_states=False,
|
||||
@@ -550,7 +630,7 @@ class MambaMixer2(CustomOp):
|
||||
|
||||
# update ssm states
|
||||
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
|
||||
mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
|
||||
ssm_state[state_indices_tensor_p] = varlen_state
|
||||
|
||||
# - reshape
|
||||
ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
|
||||
@@ -560,7 +640,7 @@ class MambaMixer2(CustomOp):
|
||||
# 2. Convolution sequence transformation
|
||||
hidden_states_B_C_d = causal_conv1d_update(
|
||||
hidden_states_B_C_d,
|
||||
mamba_cache_params.conv_state,
|
||||
conv_state,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
@@ -586,7 +666,7 @@ class MambaMixer2(CustomOp):
|
||||
# using state_indices_tensor_d
|
||||
|
||||
hidden_states_d = selective_state_update(
|
||||
mamba_cache_params.ssm_state,
|
||||
ssm_state,
|
||||
hidden_states_d,
|
||||
dt_d,
|
||||
A_d,
|
||||
@@ -598,9 +678,16 @@ class MambaMixer2(CustomOp):
|
||||
dt_softplus=True,
|
||||
state_batch_indices=state_indices_tensor_d,
|
||||
)
|
||||
ssd_output_list.append(
|
||||
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
|
||||
self.head_dim))
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
ssd_output_list.insert(
|
||||
0,
|
||||
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
|
||||
self.head_dim))
|
||||
else:
|
||||
ssd_output_list.append(
|
||||
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
|
||||
self.head_dim))
|
||||
|
||||
# Merge prefill and decode outputs before passing to gated MLP
|
||||
hidden_states = torch.vstack(ssd_output_list)
|
||||
@@ -614,3 +701,31 @@ class MambaMixer2(CustomOp):
|
||||
# 5. Final linear projection
|
||||
out, _ = self.out_proj(hidden_states)
|
||||
return out
|
||||
|
||||
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
conv_state_shape, temporal_state_shape = None, None
|
||||
|
||||
# if n_groups is not divisible by world_size, need to extend the shards
|
||||
# to ensure all groups needed by a head is sharded along with it
|
||||
n_groups = (self.n_groups +
|
||||
extra_groups_for_head_shards(self.n_groups, world_size))
|
||||
|
||||
# - heads and n_groups are TP-ed
|
||||
conv_dim = (self.intermediate_size +
|
||||
2 * n_groups * self.ssm_state_size)
|
||||
conv_state_shape = (
|
||||
divide(conv_dim, world_size),
|
||||
self.conv_kernel_size - 1,
|
||||
)
|
||||
|
||||
# These are not TP-ed as they depend on A, dt_bias, D
|
||||
# - they are typically small
|
||||
# e.g., (h_heads, d_head, d_state) = (128, 64, 128)
|
||||
temporal_state_shape = (
|
||||
divide(self.num_heads, world_size),
|
||||
self.head_dim,
|
||||
self.ssm_state_size,
|
||||
)
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
Reference in New Issue
Block a user