[Model] Reduce redundant computations in mamba2 blocks for Bamba-9B (#15423)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
This commit is contained in:
Chih-Chieh Yang
2025-04-10 15:07:07 -04:00
committed by GitHub
parent 5fbab20e02
commit daefed052c
8 changed files with 186 additions and 132 deletions

View File

@@ -25,6 +25,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -495,7 +497,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
head_dim=intermediate_size // config.n_mamba_heads,
rms_norm_eps=config.rms_norm_eps,
activation="silu",
chunk_size=config.chunk_size,
quant_config=quant_config,
)
@@ -507,7 +508,7 @@ class Zamba2MambaDecoderLayer(nn.Module):
self,
hidden_states: torch.Tensor,
mamba_cache_params: MambaCacheParams,
sequence_idx: Optional[torch.Tensor] = None,
mamba2_metadata: Mamba2Metadata,
transformer_hidden_states: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
original_hidden_states: Optional[torch.Tensor] = None,
@@ -547,7 +548,7 @@ class Zamba2MambaDecoderLayer(nn.Module):
hidden_states = self.mamba(
hidden_states,
mamba_cache_params=mamba_cache_params,
sequence_idx=sequence_idx,
mamba2_metadata=mamba2_metadata,
)
# residual connection after mamba
@@ -594,8 +595,8 @@ class Zamba2HybridLayer(nn.Module):
hidden_states: torch.Tensor,
original_hidden_states: torch.Tensor,
positions: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None,
sequence_idx: Optional[torch.Tensor] = None,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
) -> torch.Tensor:
"""Forward pass through the hybrid layer.
@@ -634,7 +635,7 @@ class Zamba2HybridLayer(nn.Module):
hidden_states,
transformer_hidden_states=transformer_hidden_states,
mamba_cache_params=mamba_cache_params,
sequence_idx=sequence_idx,
mamba2_metadata=mamba2_metadata,
)
return layer_outputs
@@ -747,20 +748,13 @@ class Zamba2Model(nn.Module):
inputs_embeds = self.get_input_embeddings(input_ids)
hidden_states = inputs_embeds
# pass a sequence index tensor, that is required for
# proper continuous batching computation including
# chunked prefill
seq_idx = None
attn_metadata = get_forward_context().attn_metadata
if attn_metadata.num_prefills > 0:
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
for i, (srt, end) in enumerate(
zip(
attn_metadata.query_start_loc,
attn_metadata.query_start_loc[1:],
)):
seq_idx[srt:end] = i
seq_idx.unsqueeze_(0)
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
input_ids=input_ids,
attn_metadata=attn_metadata,
)
# Process through layers
original_hidden_states = torch.clone(hidden_states)
@@ -770,7 +764,7 @@ class Zamba2Model(nn.Module):
original_hidden_states=original_hidden_states,
positions=positions,
mamba_cache_params=mamba_cache_params.at_layer_idx(layer_idx),
sequence_idx=seq_idx,
mamba2_metadata=mamba2_metadata,
)
hidden_states = layer_outputs