[Model] Reduce redundant computations in mamba2 blocks for Bamba-9B (#15423)
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
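Mamba2Model.forward previously built a raw per-token sequence-index tensor and handed it to every decoder layer, leaving the MambaMixer2 blocks to re-derive their chunked-prefill metadata downstream (the mixer-side half of this change is outside this excerpt). The metadata is now computed once per forward pass by the new prepare_mamba2_metadata helper and threaded through the layers as a single Mamba2Metadata object; chunk_size likewise moves out of the MambaMixer2 constructor and into the helper call. A rough sketch of the helper follows the diff.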
@@ -13,6 +13,8 @@ from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.mamba2_metadata import (
+    Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import (
     MambaMixer2, extra_groups_for_head_shards)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -57,7 +59,6 @@ class Mamba2DecoderLayer(nn.Module):
                                  head_dim=config.head_dim,
                                  rms_norm_eps=config.layer_norm_epsilon,
                                  activation=config.hidden_act,
-                                 chunk_size=config.chunk_size,
                                  quant_config=quant_config)
 
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -67,7 +68,7 @@ class Mamba2DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
         mamba_cache_params: MambaCacheParams,
-        sequence_idx: Optional[torch.Tensor],
+        mamba2_metadata: Mamba2Metadata,
         **kwargs,
     ):
         if residual is None:
@@ -77,7 +78,7 @@ class Mamba2DecoderLayer(nn.Module):
             hidden_states, residual = self.norm(hidden_states, residual)
 
         hidden_states = self.mixer(hidden_states, mamba_cache_params,
-                                   sequence_idx)
+                                   mamba2_metadata)
         return hidden_states, residual
 
 
@@ -138,20 +139,13 @@ class Mamba2Model(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        # pass a sequence index tensor, that is required for
-        # proper continuous batching computation including
-        # chunked prefill
-        seq_idx = None
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
-        if attn_metadata.num_prefills > 0:
-            seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
-            for i, (srt, end) in enumerate(
-                    zip(
-                        attn_metadata.query_start_loc,
-                        attn_metadata.query_start_loc[1:],
-                    )):
-                seq_idx[srt:end] = i
-            seq_idx.unsqueeze_(0)
+
+        mamba2_metadata = prepare_mamba2_metadata(
+            chunk_size=self.config.chunk_size,
+            input_ids=input_ids,
+            attn_metadata=attn_metadata,
+        )
 
         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -162,7 +156,7 @@ class Mamba2Model(nn.Module):
                 residual=residual,
                 mamba_cache_params=mamba_cache_params.at_layer_idx(
                     i - self.start_layer),
-                sequence_idx=seq_idx)
+                mamba2_metadata=mamba2_metadata)
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
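The mamba2_metadata module itself is not part of this diff, so the following is only a minimal sketch of what prepare_mamba2_metadata plausibly computes, reconstructed from the seq_idx logic removed above; the Mamba2Metadata fields shown are assumptions, and the real class in vllm/model_executor/layers/mamba/mamba2_metadata.py likely carries more (e.g. precomputed chunk boundaries).

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Mamba2Metadata:
    # Assumed fields; the real dataclass may differ.
    chunk_size: int
    seq_idx: Optional[torch.Tensor]  # (1, num_tokens); None for decode-only


def prepare_mamba2_metadata(chunk_size, input_ids, attn_metadata):
    # The same token-to-sequence mapping that Mamba2Model.forward used to
    # build inline, now done once per forward pass and packaged so every
    # layer can reuse it.
    seq_idx = None
    if attn_metadata.num_prefills > 0:
        seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
        for i, (srt, end) in enumerate(
                zip(attn_metadata.query_start_loc,
                    attn_metadata.query_start_loc[1:])):
            seq_idx[srt:end] = i
        seq_idx.unsqueeze_(0)
    return Mamba2Metadata(chunk_size=chunk_size, seq_idx=seq_idx)

With this in place, each Mamba2DecoderLayer forwards the one mamba2_metadata object to its mixer instead of a bare sequence_idx tensor, which is what the signature changes above reflect.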