[v1] Support mamba2 (#19327)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-06-19 04:34:15 +08:00
committed by GitHub
parent ffacb222cb
commit a89209b78d
9 changed files with 582 additions and 120 deletions

View File

@@ -6,7 +6,9 @@ from typing import Optional, Union
import torch
from torch import nn
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import get_current_vllm_config
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
@@ -27,6 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction, composed_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata
# Added by the IBM Team, 2024
@@ -227,20 +230,22 @@ class MambaMixer2(CustomOp):
"""
def __init__(
self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
quant_config: Optional[QuantizationConfig] = None,
self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
chunk_size: int = -1, # the chunk size used by v1
):
super().__init__()
@@ -273,6 +278,7 @@ class MambaMixer2(CustomOp):
), "Tensor parallel currently not supported for quantized models."
self.ssm_state_size = ssm_state_size
self.conv_kernel_size = conv_kernel_size
self.activation = activation
self.intermediate_size = intermediate_size
@@ -411,6 +417,22 @@ class MambaMixer2(CustomOp):
self.use_rms_norm,
eps=rms_norm_eps)
if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state)
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
assert chunk_size != -1, "chunk_size must be set for v1"
# NOTE: chunk_size may be -1 for models without v1 support
self.chunk_size = chunk_size
self.prefix = prefix
def forward_native(
self,
hidden_states: torch.Tensor,
@@ -426,17 +448,37 @@ class MambaMixer2(CustomOp):
mamba2_metadata: Mamba2Metadata,
mup_vector: Optional[torch.Tensor] = None,
):
forward_context = get_forward_context()
# mamba2_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# stay the same and reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if envs.VLLM_USE_V1:
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0]
ssm_state = self_kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states
prep_initial_states = attn_metadata.prep_initial_states
chunk_size = attn_metadata.chunk_size
seq_idx_p = attn_metadata.seq_idx
chunk_indices_p = attn_metadata.chunk_indices
chunk_offsets_p = attn_metadata.chunk_offsets
else:
conv_state = mamba_cache_params.conv_state
ssm_state = mamba_cache_params.ssm_state
state_indices_tensor = mamba_cache_params.state_indices_tensor
has_initial_states_p = mamba2_metadata.has_initial_states
prep_initial_states = mamba2_metadata.prep_initial_states
chunk_size = mamba2_metadata.chunk_size
seq_idx_p = mamba2_metadata.seq_idx
chunk_indices_p = mamba2_metadata.chunk_indices
chunk_offsets_p = mamba2_metadata.chunk_offsets
groups_time_state_size = self.n_groups * self.ssm_state_size
@@ -459,27 +501,6 @@ class MambaMixer2(CustomOp):
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
# Separate prefill and decode by splitting varlen input
# Split along token dimension
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
hidden_states_B_C,
[num_prefill_tokens, num_decodes],
dim=0,
)
dt_p, dt_d = torch.split(
dt,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
mamba_cache_params.state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
if has_prefill else None)
# - get hidden_states, B and C after depthwise convolution.
split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
hidden_states_B_C,
@@ -491,20 +512,80 @@ class MambaMixer2(CustomOp):
dim=-1,
)
if envs.VLLM_USE_V1 and attn_metadata is None:
# V1 profile run
hidden_states_B_C = (hidden_states_B_C.transpose(
0, 1).clone().transpose(0, 1)).contiguous()
hidden_states, _B, _C = split_hidden_states_B_C_fn(
hidden_states_B_C)
hidden_states = self.norm(hidden_states, gate)
out, _ = self.out_proj(hidden_states)
return out
num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Split along token dimension
if envs.VLLM_USE_V1:
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
hidden_states_B_C,
[num_decodes, num_prefill_tokens],
dim=0,
)
dt_d, dt_p = torch.split(
dt,
[num_decodes, num_prefill_tokens],
dim=0,
)
# Split along batch dimension
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
[num_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None)
else:
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
hidden_states_B_C,
[num_prefill_tokens, num_decodes],
dim=0,
)
dt_p, dt_d = torch.split(
dt,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
1]
if has_prefill else None)
ssd_output_list = []
# Process prefill requests
if has_prefill:
# 2. Convolution sequence transformation
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "mamba_cache_params.state_indices_tensor"
# pointed to by "state_indices_tensor"
hidden_states_B_C_p = causal_conv1d_fn(
hidden_states_B_C_p.transpose(0, 1),
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=mamba2_metadata.has_initial_states,
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]
@@ -516,12 +597,11 @@ class MambaMixer2(CustomOp):
# 3. State Space Model sequence transformation
initial_states = None
if (mamba2_metadata.has_initial_states is not None
and mamba2_metadata.prep_initial_states):
if (has_initial_states_p is not None and prep_initial_states):
# making a copy of the states
initial_states = torch.where(
mamba2_metadata.has_initial_states[:, None, None, None],
mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
has_initial_states_p[:, None, None, None],
ssm_state[state_indices_tensor_p], 0)
scan_output, varlen_state = mamba_chunk_scan_combined(
hidden_states_p.view(1, num_prefill_tokens,
@@ -533,14 +613,14 @@ class MambaMixer2(CustomOp):
-1),
C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
-1),
chunk_size=mamba2_metadata.chunk_size,
chunk_size=chunk_size,
D=self.D,
z=None,
dt_bias=self.dt_bias,
seq_idx=mamba2_metadata.seq_idx,
chunk_indices=mamba2_metadata.chunk_indices,
chunk_offsets=mamba2_metadata.chunk_offsets,
cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1],
seq_idx=seq_idx_p,
chunk_indices=chunk_indices_p,
chunk_offsets=chunk_offsets_p,
cu_seqlens=query_start_loc_p,
initial_states=initial_states,
return_varlen_states=True,
return_final_states=False,
@@ -550,7 +630,7 @@ class MambaMixer2(CustomOp):
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
ssm_state[state_indices_tensor_p] = varlen_state
# - reshape
ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
@@ -560,7 +640,7 @@ class MambaMixer2(CustomOp):
# 2. Convolution sequence transformation
hidden_states_B_C_d = causal_conv1d_update(
hidden_states_B_C_d,
mamba_cache_params.conv_state,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
@@ -586,7 +666,7 @@ class MambaMixer2(CustomOp):
# using state_indices_tensor_d
hidden_states_d = selective_state_update(
mamba_cache_params.ssm_state,
ssm_state,
hidden_states_d,
dt_d,
A_d,
@@ -598,9 +678,16 @@ class MambaMixer2(CustomOp):
dt_softplus=True,
state_batch_indices=state_indices_tensor_d,
)
ssd_output_list.append(
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
self.head_dim))
if envs.VLLM_USE_V1:
ssd_output_list.insert(
0,
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
self.head_dim))
else:
ssd_output_list.append(
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
self.head_dim))
# Merge prefill and decode outputs before passing to gated MLP
hidden_states = torch.vstack(ssd_output_list)
@@ -614,3 +701,31 @@ class MambaMixer2(CustomOp):
# 5. Final linear projection
out, _ = self.out_proj(hidden_states)
return out
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
world_size = get_tensor_model_parallel_world_size()
conv_state_shape, temporal_state_shape = None, None
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
n_groups = (self.n_groups +
extra_groups_for_head_shards(self.n_groups, world_size))
# - heads and n_groups are TP-ed
conv_dim = (self.intermediate_size +
2 * n_groups * self.ssm_state_size)
conv_state_shape = (
divide(conv_dim, world_size),
self.conv_kernel_size - 1,
)
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, d_head, d_state) = (128, 64, 128)
temporal_state_shape = (
divide(self.num_heads, world_size),
self.head_dim,
self.ssm_state_size,
)
return conv_state_shape, temporal_state_shape