[Model] Reduce redundant computations in mamba2 blocks for Bamba-9B (#15423)
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
This commit is contained in:
@@ -6,10 +6,6 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.attention.backends.placeholder_attn import (
|
||||
PlaceholderAttentionMetadata)
|
||||
from vllm.attention.backends.xformers import XFormersMetadata
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
@@ -18,6 +14,7 @@ from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
@@ -221,7 +218,6 @@ class MambaMixer2(CustomOp):
|
||||
head_dim: int = 64,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
activation="silu",
|
||||
chunk_size: int = 256,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
|
||||
@@ -257,7 +253,6 @@ class MambaMixer2(CustomOp):
|
||||
self.ssm_state_size = ssm_state_size
|
||||
self.activation = activation
|
||||
|
||||
self.chunk_size = chunk_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = num_heads
|
||||
@@ -388,25 +383,17 @@ class MambaMixer2(CustomOp):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
sequence_idx: Optional[torch.Tensor] = None,
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
):
|
||||
# mamba2_metadata contains metadata necessary for the mamba2 triton
|
||||
# kernels to operate in continuous batching and in chunked prefill
|
||||
# modes; they are computed at top-level model forward since they
|
||||
# are the same and reused for all mamba layers in the same iteration
|
||||
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
|
||||
|
||||
seq_len, _ = hidden_states.shape
|
||||
groups_time_state_size = self.n_groups * self.ssm_state_size
|
||||
|
||||
# detect if there are prefills
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
|
||||
# - also need flags to indicate if there are initial states
|
||||
# - currently we really only support the FlashAttention backend
|
||||
has_initial_states = None
|
||||
if (isinstance(attn_metadata,
|
||||
(FlashAttentionMetadata, XFormersMetadata,
|
||||
PlaceholderAttentionMetadata))
|
||||
and attn_metadata.context_lens_tensor is not None):
|
||||
has_initial_states = attn_metadata.context_lens_tensor > 0
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states, _ = self.in_proj(hidden_states)
|
||||
gate, hidden_states_B_C, dt = torch.split(
|
||||
@@ -423,7 +410,7 @@ class MambaMixer2(CustomOp):
|
||||
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
|
||||
self.conv1d.weight.size(2))
|
||||
|
||||
if has_prefill:
|
||||
if mamba2_metadata.has_prefill:
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
@@ -439,7 +426,7 @@ class MambaMixer2(CustomOp):
|
||||
self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
conv_states=mamba_cache_params.conv_state,
|
||||
has_initial_state=has_initial_states,
|
||||
has_initial_state=mamba2_metadata.has_initial_states,
|
||||
cache_indices=mamba_cache_params.state_indices_tensor,
|
||||
query_start_loc=attn_metadata.query_start_loc).transpose(
|
||||
0, 1)[:seq_len]
|
||||
@@ -467,16 +454,15 @@ class MambaMixer2(CustomOp):
|
||||
)
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
if has_prefill:
|
||||
|
||||
if mamba2_metadata.has_prefill:
|
||||
initial_states = None
|
||||
if has_initial_states is not None and torch.any(
|
||||
has_initial_states):
|
||||
zero_init_indices = mamba_cache_params.state_indices_tensor[
|
||||
~has_initial_states]
|
||||
mamba_cache_params.ssm_state[zero_init_indices] = 0
|
||||
initial_states = mamba_cache_params.ssm_state[
|
||||
mamba_cache_params.state_indices_tensor]
|
||||
if (mamba2_metadata.has_initial_states is not None
|
||||
and mamba2_metadata.prep_initial_states):
|
||||
# making a copy of the states
|
||||
initial_states = torch.where(
|
||||
mamba2_metadata.has_initial_states[:, None, None, None],
|
||||
mamba_cache_params.ssm_state[
|
||||
mamba_cache_params.state_indices_tensor], 0)
|
||||
|
||||
scan_output, varlen_state = mamba_chunk_scan_combined(
|
||||
hidden_states.view(1, seq_len, self.num_heads // self.tp_size,
|
||||
@@ -485,11 +471,13 @@ class MambaMixer2(CustomOp):
|
||||
self.A,
|
||||
B.view(1, seq_len, self.n_groups // self.tp_size, -1),
|
||||
C.view(1, seq_len, self.n_groups // self.tp_size, -1),
|
||||
chunk_size=self.chunk_size,
|
||||
chunk_size=mamba2_metadata.chunk_size,
|
||||
D=self.D,
|
||||
z=None,
|
||||
dt_bias=self.dt_bias,
|
||||
seq_idx=sequence_idx,
|
||||
seq_idx=mamba2_metadata.seq_idx,
|
||||
chunk_indices=mamba2_metadata.chunk_indices,
|
||||
chunk_offsets=mamba2_metadata.chunk_offsets,
|
||||
cu_seqlens=attn_metadata.query_start_loc,
|
||||
initial_states=initial_states,
|
||||
return_varlen_states=True,
|
||||
|
||||
Reference in New Issue
Block a user