[Model] Reduce redundant computations in mamba2 blocks for Bamba-9B (#15423)
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
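This commit removes redundant per-layer work in the Zamba2/Bamba model code: Zamba2Model.forward now calls prepare_mamba2_metadata once per forward pass and threads the resulting Mamba2Metadata object through every hybrid and mamba decoder layer, instead of handing each layer a raw sequence_idx tensor. chunk_size likewise moves out of the MambaMixer2 constructor and into the shared metadata. The sketch below is a minimal, self-contained illustration of the hoisted computation, not vLLM's actual API: SharedMamba2Metadata and prepare_metadata are hypothetical stand-ins for Mamba2Metadata and prepare_mamba2_metadata, and the seq_idx derivation mirrors the code this diff deletes from Zamba2Model.forward.

    # Minimal sketch (hypothetical names, not vLLM's API): compute per-prefill
    # sequence indices once per forward pass, bundle them in a metadata object,
    # and pass that object to every mamba2 layer.
    from dataclasses import dataclass
    from typing import Optional

    import torch


    @dataclass
    class SharedMamba2Metadata:  # stand-in for vLLM's Mamba2Metadata
        chunk_size: int
        # (1, num_tokens) tensor mapping each token to its sequence index,
        # or None when the batch contains no prefills (pure decode).
        seq_idx: Optional[torch.Tensor]


    def prepare_metadata(chunk_size: int, input_ids: torch.Tensor,
                         query_start_loc: torch.Tensor,
                         num_prefills: int) -> SharedMamba2Metadata:
        # Same derivation the diff removes from Zamba2Model.forward: label
        # every token with the index of the sequence it belongs to, using
        # consecutive pairs of query_start_loc as (start, end) bounds.
        seq_idx = None
        if num_prefills > 0:
            seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
            for i, (srt, end) in enumerate(
                    zip(query_start_loc, query_start_loc[1:])):
                seq_idx[srt:end] = i
            seq_idx = seq_idx.unsqueeze(0)
        return SharedMamba2Metadata(chunk_size=chunk_size, seq_idx=seq_idx)


    # Usage sketch: prepare once at the model level, reuse in every layer.
    metadata = prepare_metadata(chunk_size=256,
                                input_ids=torch.arange(7),
                                query_start_loc=torch.tensor([0, 3, 7]),
                                num_prefills=2)
    # metadata.seq_idx -> tensor([[0, 0, 0, 1, 1, 1, 1]], dtype=torch.int32)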
@@ -25,6 +25,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.mamba2_metadata import (
+    Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import (
     MambaMixer2, extra_groups_for_head_shards)
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -495,7 +497,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
             head_dim=intermediate_size // config.n_mamba_heads,
             rms_norm_eps=config.rms_norm_eps,
             activation="silu",
-            chunk_size=config.chunk_size,
             quant_config=quant_config,
         )
 
@@ -507,7 +508,7 @@ class Zamba2MambaDecoderLayer(nn.Module):
         self,
         hidden_states: torch.Tensor,
         mamba_cache_params: MambaCacheParams,
-        sequence_idx: Optional[torch.Tensor] = None,
+        mamba2_metadata: Mamba2Metadata,
         transformer_hidden_states: Optional[torch.Tensor] = None,
         positions: Optional[torch.Tensor] = None,
         original_hidden_states: Optional[torch.Tensor] = None,
@@ -547,7 +548,7 @@ class Zamba2MambaDecoderLayer(nn.Module):
         hidden_states = self.mamba(
             hidden_states,
             mamba_cache_params=mamba_cache_params,
-            sequence_idx=sequence_idx,
+            mamba2_metadata=mamba2_metadata,
         )
 
         # residual connection after mamba
@@ -594,8 +595,8 @@ class Zamba2HybridLayer(nn.Module):
         hidden_states: torch.Tensor,
         original_hidden_states: torch.Tensor,
         positions: torch.Tensor,
-        mamba_cache_params: Optional[MambaCacheParams] = None,
-        sequence_idx: Optional[torch.Tensor] = None,
+        mamba_cache_params: MambaCacheParams,
+        mamba2_metadata: Mamba2Metadata,
     ) -> torch.Tensor:
         """Forward pass through the hybrid layer.
 
@@ -634,7 +635,7 @@ class Zamba2HybridLayer(nn.Module):
             hidden_states,
             transformer_hidden_states=transformer_hidden_states,
             mamba_cache_params=mamba_cache_params,
-            sequence_idx=sequence_idx,
+            mamba2_metadata=mamba2_metadata,
         )
 
         return layer_outputs
@@ -747,20 +748,13 @@ class Zamba2Model(nn.Module):
             inputs_embeds = self.get_input_embeddings(input_ids)
         hidden_states = inputs_embeds
 
-        # pass a sequence index tensor, that is required for
-        # proper continuous batching computation including
-        # chunked prefill
-        seq_idx = None
         attn_metadata = get_forward_context().attn_metadata
-        if attn_metadata.num_prefills > 0:
-            seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
-            for i, (srt, end) in enumerate(
-                    zip(
-                        attn_metadata.query_start_loc,
-                        attn_metadata.query_start_loc[1:],
-                    )):
-                seq_idx[srt:end] = i
-            seq_idx.unsqueeze_(0)
+
+        mamba2_metadata = prepare_mamba2_metadata(
+            chunk_size=self.config.chunk_size,
+            input_ids=input_ids,
+            attn_metadata=attn_metadata,
+        )
 
         # Process through layers
         original_hidden_states = torch.clone(hidden_states)
@@ -770,7 +764,7 @@ class Zamba2Model(nn.Module):
                 original_hidden_states=original_hidden_states,
                 positions=positions,
                 mamba_cache_params=mamba_cache_params.at_layer_idx(layer_idx),
-                sequence_idx=seq_idx,
+                mamba2_metadata=mamba2_metadata,
             )
             hidden_states = layer_outputs
 