[Model] Reduce redundant computations in mamba2 blocks for Bamba-9B (#15423)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
Author: Chih-Chieh Yang
Date: 2025-04-10 15:07:07 -04:00 (committed by GitHub)
parent 5fbab20e02
commit daefed052c
8 changed files with 186 additions and 132 deletions
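Summary of the change: previously, every Mamba2 decoder layer rebuilt the sequence-index tensor needed for continuous batching and chunked prefill. This commit hoists that work into a single prepare_mamba2_metadata() call in Mamba2Model.forward and threads the resulting Mamba2Metadata object through each layer. A minimal sketch of the idea follows, reusing the exact seq_idx loop this commit deletes; the field names (has_prefill, chunk_size) are illustrative assumptions, not the exact contents of vllm/model_executor/layers/mamba/mamba2_metadata.py:

    from dataclasses import dataclass
    from typing import Optional

    import torch


    @dataclass
    class Mamba2Metadata:
        # Field names here are assumptions for illustration; see the real
        # definition in vllm/model_executor/layers/mamba/mamba2_metadata.py.
        has_prefill: bool
        chunk_size: int
        seq_idx: Optional[torch.Tensor]  # maps each prefill token to its sequence


    def prepare_mamba2_metadata(chunk_size, input_ids, attn_metadata):
        # Runs once per forward pass; the body below is the same loop this
        # commit removes from the per-layer path (see the last hunk).
        has_prefill = attn_metadata.num_prefills > 0
        seq_idx = None
        if has_prefill:
            seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
            for i, (srt, end) in enumerate(
                    zip(attn_metadata.query_start_loc,
                        attn_metadata.query_start_loc[1:])):
                seq_idx[srt:end] = i
            seq_idx.unsqueeze_(0)
        return Mamba2Metadata(has_prefill=has_prefill,
                              chunk_size=chunk_size,
                              seq_idx=seq_idx)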


@@ -13,6 +13,8 @@ from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.mamba2_metadata import (
+    Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import (
     MambaMixer2, extra_groups_for_head_shards)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -57,7 +59,6 @@ class Mamba2DecoderLayer(nn.Module):
             head_dim=config.head_dim,
             rms_norm_eps=config.layer_norm_epsilon,
             activation=config.hidden_act,
-            chunk_size=config.chunk_size,
             quant_config=quant_config)
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -67,7 +68,7 @@ class Mamba2DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
         mamba_cache_params: MambaCacheParams,
-        sequence_idx: Optional[torch.Tensor],
+        mamba2_metadata: Mamba2Metadata,
         **kwargs,
     ):
         if residual is None:
@@ -77,7 +78,7 @@ class Mamba2DecoderLayer(nn.Module):
             hidden_states, residual = self.norm(hidden_states, residual)
         hidden_states = self.mixer(hidden_states, mamba_cache_params,
-                                   sequence_idx)
+                                   mamba2_metadata)
         return hidden_states, residual
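Assembled from the two hunks above, the decoder layer's forward now reads roughly as follows. The residual-handling branch between the hunks is not shown in the diff and is paraphrased from the usual vLLM pre-norm pattern, so treat it as an assumption:

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        mamba_cache_params: MambaCacheParams,
        mamba2_metadata: Mamba2Metadata,
        **kwargs,
    ):
        # Standard pre-norm residual handling (assumed; elided in the diff).
        if residual is None:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
        else:
            hidden_states, residual = self.norm(hidden_states, residual)
        # The mixer now receives the shared metadata instead of a
        # per-layer sequence-index tensor.
        hidden_states = self.mixer(hidden_states, mamba_cache_params,
                                   mamba2_metadata)
        return hidden_states, residual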
@@ -138,20 +139,13 @@ class Mamba2Model(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        # pass a sequence index tensor, that is required for
-        # proper continuous batching computation including
-        # chunked prefill
-        seq_idx = None
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
-        if attn_metadata.num_prefills > 0:
-            seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
-            for i, (srt, end) in enumerate(
-                    zip(
-                        attn_metadata.query_start_loc,
-                        attn_metadata.query_start_loc[1:],
-                    )):
-                seq_idx[srt:end] = i
-            seq_idx.unsqueeze_(0)
+        mamba2_metadata = prepare_mamba2_metadata(
+            chunk_size=self.config.chunk_size,
+            input_ids=input_ids,
+            attn_metadata=attn_metadata,
+        )

         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -162,7 +156,7 @@ class Mamba2Model(nn.Module):
                 residual=residual,
                 mamba_cache_params=mamba_cache_params.at_layer_idx(
                     i - self.start_layer),
-                sequence_idx=seq_idx)
+                mamba2_metadata=mamba2_metadata)

         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
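With N decoder layers, the old code ran the Python seq_idx loop N times per forward pass during prefill; it now runs once. The resulting call pattern in Mamba2Model.forward, condensed from the hunks above (per-layer kwargs other than those shown in the diff are elided):

    attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
    mamba2_metadata = prepare_mamba2_metadata(
        chunk_size=self.config.chunk_size,
        input_ids=input_ids,
        attn_metadata=attn_metadata,
    )
    for i in range(len(self.layers)):
        hidden_states, residual = self.layers[i](
            hidden_states=hidden_states,
            residual=residual,
            mamba_cache_params=mamba_cache_params.at_layer_idx(
                i - self.start_layer),
            mamba2_metadata=mamba2_metadata,  # shared, not recomputed per layer
        )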