[Model] Reduce redundant computations in mamba2 blocks for Bamba-9B (#15423)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
This commit is contained in:
Chih-Chieh Yang
2025-04-10 15:07:07 -04:00
committed by GitHub
parent 5fbab20e02
commit daefed052c
8 changed files with 186 additions and 132 deletions

View File

@@ -18,6 +18,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -94,7 +96,6 @@ class BambaMixerDecoderLayer(nn.Module):
head_dim=config.mamba_d_head,
rms_norm_eps=config.rms_norm_eps,
activation=config.hidden_act,
chunk_size=config.mamba_chunk_size,
quant_config=quant_config)
self.feed_forward = BambaMLP(config, quant_config=quant_config)
@@ -108,7 +109,7 @@ class BambaMixerDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
sequence_idx: Optional[torch.Tensor] = None,
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
if residual is None:
@@ -119,7 +120,7 @@ class BambaMixerDecoderLayer(nn.Module):
hidden_states, residual)
hidden_states = self.mamba(hidden_states, mamba_cache_params,
sequence_idx)
mamba2_metadata)
# Fully Connected
hidden_states, residual = self.pre_ff_layernorm(
hidden_states, residual)
@@ -259,7 +260,7 @@ class BambaModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
config: BambaConfig = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
@@ -309,20 +310,13 @@ class BambaModel(nn.Module):
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# pass a sequence index tensor, that is required for
# proper continuous batching computation including
# chunked prefill
seq_idx = None
attn_metadata = get_forward_context().attn_metadata
if attn_metadata.num_prefills > 0:
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
for i, (srt, end) in enumerate(
zip(
attn_metadata.query_start_loc,
attn_metadata.query_start_loc[1:],
)):
seq_idx[srt:end] = i
seq_idx.unsqueeze_(0)
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
input_ids=input_ids,
attn_metadata=attn_metadata,
)
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
@@ -352,7 +346,7 @@ class BambaModel(nn.Module):
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=layer_mamba_cache_params,
sequence_idx=seq_idx,
mamba2_metadata=mamba2_metadata,
)
if not get_pp_group().is_last_rank:
@@ -555,4 +549,4 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
return loader.load_weights(weights)