Enable V1 for Hybrid SSM/Attention Models (#20016)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Stanislaw Wozniak <stw@zurich.ibm.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
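For context, a minimal offline-inference sketch of what this change enables: running a hybrid SSM/attention checkpoint on the V1 engine. This is not part of the commit; the checkpoint id and sampling settings are illustrative, and the environment variable is set explicitly only to make the engine choice visible. The hunks that follow appear to touch the FalconH1 model definition (vllm/model_executor/models/falcon_h1.py).

import os

# Force the V1 engine before importing vllm so vllm.envs picks the value up.
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM, SamplingParams

# Illustrative Falcon-H1 checkpoint id; substitute any hybrid SSM/attention model.
llm = LLM(model="tiiuae/Falcon-H1-1.5B-Instruct")
params = SamplingParams(temperature=0.0, max_tokens=64)

outputs = llm.generate(["Hybrid Mamba/attention models are"], params)
print(outputs[0].outputs[0].text)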
@@ -8,6 +8,7 @@ import torch
 from torch import nn
 from transformers import FalconH1Config

+from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -33,8 +34,7 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

-from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
-                         SupportsV0Only)
+from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
 from .utils import (PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -85,6 +85,7 @@ class FalconH1SSMDecoderLayer(nn.Module):
         config: FalconH1Config,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -107,6 +108,8 @@ class FalconH1SSMDecoderLayer(nn.Module):
             activation=config.hidden_act,
             quant_config=quant_config,
             use_rms_norm=config.mamba_rms_norm,
+            prefix=f"{prefix}.mixer",
+            chunk_size=config.mamba_chunk_size,
         )
         # n_groups is overridden later by `MambaMixer2`
         self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state
@@ -316,6 +319,7 @@ class FalconH1ParallelHybrid(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+
         # Instantiate the attention branch
         self.self_attn = FalconH1AttentionDecoderLayer(
             config=config,
@@ -323,11 +327,18 @@ class FalconH1ParallelHybrid(nn.Module):
             quant_config=quant_config,
             prefix=prefix,
         )
+
+        # In V1 all attention/ssm layers must have
+        # different index in prefix
+        ssm_layer_idx = config.num_hidden_layers + layer_idx
+        ssm_prefix = prefix.split(".")[0] + f".{ssm_layer_idx}"
+
         # Instantiate the SSM branch
         self.mamba = FalconH1SSMDecoderLayer(
             config=config,
             cache_config=cache_config,
             quant_config=quant_config,
+            prefix=ssm_prefix,
         )
         self.ssm_out_multiplier = config.ssm_out_multiplier
         self.ssm_in_multiplier = config.ssm_in_multiplier
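Editorial aside, not part of the diff: the two comment lines added above rely on a small indexing trick so that, under V1, the attention branch and the SSM branch of the same hybrid layer register under distinct layer indices. A sketch with illustrative numbers follows; the prefix format and layer count are assumptions.

num_hidden_layers = 36                          # assumed config.num_hidden_layers
layer_idx = 4                                   # index of this hybrid layer
prefix = f"model.layers.{layer_idx}"            # assumed prefix passed to the layer

# The attention branch keeps `prefix`; the SSM branch is offset past every real
# layer index so the two branches can never share an index.
ssm_layer_idx = num_hidden_layers + layer_idx              # 40
ssm_prefix = prefix.split(".")[0] + f".{ssm_layer_idx}"    # "model.40"
print(prefix, "->", ssm_prefix)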
@@ -452,10 +463,16 @@ class FalconH1Model(nn.Module):
         # proper continuous batching computation including
         # chunked prefill
         attn_metadata = get_forward_context().attn_metadata
-        mamba2_metadata = prepare_mamba2_metadata(
-            chunk_size=self.config.mamba_chunk_size,
-            attn_metadata=attn_metadata,
-        )
+
+        if not envs.VLLM_USE_V1:
+            mamba2_metadata = prepare_mamba2_metadata(
+                chunk_size=self.config.mamba_chunk_size,
+                attn_metadata=attn_metadata,
+            )
+        else:
+            # v1 get mamba2_metadata from forward_context
+            mamba2_metadata = None
+
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds * self.embedding_multiplier
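Editorial aside, not part of the diff: under V1 the model leaves mamba2_metadata as None because each mixer layer is expected to read its metadata from the forward context instead. A rough sketch of that lookup is below; the assumption that V1 keeps per-layer attention metadata in a dict keyed by layer name is mine, not something stated in this diff.

from vllm.forward_context import get_forward_context

def lookup_layer_metadata(layer_name: str):
    # V0 exposes one shared metadata object; V1 (assumed here) exposes a
    # per-layer mapping keyed by the layer's prefix/name.
    attn_metadata = get_forward_context().attn_metadata
    if isinstance(attn_metadata, dict):
        return attn_metadata[layer_name]
    return attn_metadata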
@@ -468,7 +485,9 @@ class FalconH1Model(nn.Module):

         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
-            layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
+            layer_mamba_cache_params = None
+            if mamba_cache_params:
+                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
             hidden_states = layer(
                 positions=positions,
                 hidden_states=hidden_states,
@@ -484,7 +503,7 @@ class FalconH1Model(nn.Module):


 class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
-                          IsHybrid, SupportsV0Only):
+                          IsHybrid):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -558,15 +577,19 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs,
     ):
-        if self.mamba_cache is None:
-            self.mamba_cache = MambaCacheManager(
-                self.vllm_config,
-                self.lm_head.weight.dtype
-                if hasattr(self.lm_head, 'weight') else torch.bfloat16,
-                self.config.num_hidden_layers,
-                *self._get_mamba_cache_shape(),
-            )
-        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+
+        mamba_cache_params = None
+        if not envs.VLLM_USE_V1:
+            if self.mamba_cache is None:
+                self.mamba_cache = MambaCacheManager(
+                    self.vllm_config,
+                    self.lm_head.weight.dtype if hasattr(
+                        self.lm_head, 'weight') else torch.bfloat16,
+                    self.config.num_hidden_layers,
+                    *self._get_mamba_cache_shape(),
+                )
+            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+
         hidden_states = self.model(
             input_ids,
             positions,