[Model] FalconMamba Support (#9325)

This commit is contained in:
Dhia Eddine Rhaiem
2024-10-21 20:50:16 +04:00
committed by GitHub
parent 496e991da8
commit f6b97293aa
5 changed files with 35 additions and 12 deletions

View File

@@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader, default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.interfaces import (HasInnerState,
@@ -59,7 +59,7 @@ class MambaMixer(nn.Module):
self.conv_kernel_size = config.conv_kernel
self.intermediate_size = config.intermediate_size
self.time_step_rank = int(config.time_step_rank)
self.is_falcon_mamba = config.model_type == "falcon_mamba"
self.conv1d = ColumnParallelLinear(
input_size=self.conv_kernel_size,
output_size=self.intermediate_size,
@@ -109,6 +109,13 @@ class MambaMixer(nn.Module):
input_is_parallel=True,
)
self.activation = config.hidden_act
if self.is_falcon_mamba:
self.dt_layernorm = RMSNorm(self.time_step_rank,
eps=config.mixer_rms_eps)
self.b_layernorm = RMSNorm(self.ssm_state_size,
eps=config.mixer_rms_eps)
self.c_layernorm = RMSNorm(self.ssm_state_size,
eps=config.mixer_rms_eps)
def forward(self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
@@ -158,8 +165,12 @@ class MambaMixer(nn.Module):
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
dim=-1,
)
# Note that Jamba normalizes B, C, and time_step here but Mamba doesn't.
# Note that Jamba and FalconMamba normalizes B, C, and time_step here
# but Mamba doesn't.
if self.is_falcon_mamba:
time_step = self.dt_layernorm(time_step.contiguous())
B = self.b_layernorm(B.contiguous())
C = self.c_layernorm(C.contiguous())
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
@@ -213,11 +224,9 @@ class MambaDecoderLayer(nn.Module):
super().__init__()
self.layer_idx = layer_idx
self.config = config
self.is_falcon_mamba = config.model_type == "falcon_mamba"
self.mixer = MambaMixer(config, layer_idx)
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.pre_ff_layernorm = RMSNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
def forward(
self,
@@ -319,8 +328,18 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = self.backbone.embeddings
if config.tie_word_embeddings:
self.lm_head = self.backbone.embeddings
else:
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
@@ -398,7 +417,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
for name, loaded_weight in weights:
if "A_log" in name:
name = name.replace("A_log", "A")
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue