[Model] Reduce redundant computations in mamba2 blocks for Bamba-9B (#15423)
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
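Summary of the change in this file: BambaModel.forward previously built a per-token sequence-index tensor (seq_idx) inline and threaded it through every decoder layer; it now calls prepare_mamba2_metadata once per forward pass and passes the resulting Mamba2Metadata object to the layers instead. Per the commit title, this lets the mamba2 blocks share that metadata rather than each rederiving it.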
@@ -18,6 +18,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.mamba2_metadata import (
+    Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import (
     MambaMixer2, extra_groups_for_head_shards)
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -94,7 +96,6 @@ class BambaMixerDecoderLayer(nn.Module):
                                  head_dim=config.mamba_d_head,
                                  rms_norm_eps=config.rms_norm_eps,
                                  activation=config.hidden_act,
-                                 chunk_size=config.mamba_chunk_size,
                                  quant_config=quant_config)
 
         self.feed_forward = BambaMLP(config, quant_config=quant_config)
@@ -108,7 +109,7 @@ class BambaMixerDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
         mamba_cache_params: MambaCacheParams,
-        sequence_idx: Optional[torch.Tensor] = None,
+        mamba2_metadata: Mamba2Metadata,
         **kwargs,
     ):
         if residual is None:
@@ -119,7 +120,7 @@ class BambaMixerDecoderLayer(nn.Module):
                 hidden_states, residual)
 
         hidden_states = self.mamba(hidden_states, mamba_cache_params,
-                                   sequence_idx)
+                                   mamba2_metadata)
         # Fully Connected
         hidden_states, residual = self.pre_ff_layernorm(
             hidden_states, residual)
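For context on the signature change above: the mixer layer now receives a Mamba2Metadata object rather than a bare sequence_idx tensor, and (per the earlier hunk) its constructor no longer takes chunk_size, since that value can travel with the metadata. The exact fields of Mamba2Metadata are not shown in this diff, so the sketch below is illustrative only; the field names are assumptions.

# Illustrative sketch only -- not the actual vLLM definition of Mamba2Metadata.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Mamba2MetadataSketch:
    chunk_size: int                          # previously a per-layer constructor argument
    seq_idx: Optional[torch.Tensor] = None   # per-token sequence index for packed prefills
    # ... plus whatever other prefill bookkeeping the mamba2 blocks share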
@@ -259,7 +260,7 @@ class BambaModel(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
 
-        config = vllm_config.model_config.hf_config
+        config: BambaConfig = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
@@ -309,20 +310,13 @@ class BambaModel(nn.Module):
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
-        # pass a sequence index tensor, that is required for
-        # proper continuous batching computation including
-        # chunked prefill
-        seq_idx = None
         attn_metadata = get_forward_context().attn_metadata
-        if attn_metadata.num_prefills > 0:
-            seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
-            for i, (srt, end) in enumerate(
-                    zip(
-                        attn_metadata.query_start_loc,
-                        attn_metadata.query_start_loc[1:],
-                    )):
-                seq_idx[srt:end] = i
-            seq_idx.unsqueeze_(0)
+
+        mamba2_metadata = prepare_mamba2_metadata(
+            chunk_size=self.config.mamba_chunk_size,
+            input_ids=input_ids,
+            attn_metadata=attn_metadata,
+        )
 
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
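The seq_idx construction removed here is presumably what prepare_mamba2_metadata now performs once per forward pass (its internals are not part of this diff, and it may compute additional shared quantities). A self-contained sketch of the removed logic, for reference:

import torch


def build_seq_idx(input_ids: torch.Tensor,
                  query_start_loc: torch.Tensor) -> torch.Tensor:
    """Map each packed prefill token to the index of the sequence it belongs to."""
    seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
    for i, (srt, end) in enumerate(zip(query_start_loc,
                                       query_start_loc[1:])):
        seq_idx[srt:end] = i
    return seq_idx.unsqueeze(0)


# Example: two prefill sequences of lengths 3 and 2 packed into one batch.
input_ids = torch.tensor([11, 12, 13, 21, 22])
query_start_loc = torch.tensor([0, 3, 5])
print(build_seq_idx(input_ids, query_start_loc))
# tensor([[0, 0, 0, 1, 1]], dtype=torch.int32)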
@@ -352,7 +346,7 @@ class BambaModel(nn.Module):
                 hidden_states=hidden_states,
                 residual=residual,
                 mamba_cache_params=layer_mamba_cache_params,
-                sequence_idx=seq_idx,
+                mamba2_metadata=mamba2_metadata,
             )
 
         if not get_pp_group().is_last_rank:
@@ -555,4 +549,4 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
\ No newline at end of file
+        return loader.load_weights(weights)