Enable V1 for Hybrid SSM/Attention Models (#20016)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Stanislaw Wozniak <stw@zurich.ibm.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Author: Thomas Parnell
Date: 2025-07-04 19:46:53 +02:00
Committed by: GitHub
Parent: ffe00ef77a
Commit: 2f35a022e6
14 changed files with 399 additions and 134 deletions
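This commit removes the SupportsV0Only marker from FalconH1 and gates the V0-specific Mamba cache handling behind envs.VLLM_USE_V1, so hybrid SSM/attention models can run on the V1 engine. A minimal sketch of exercising the new path with vLLM's offline API; the checkpoint name is illustrative, not taken from this commit:

import os

# Select the V1 engine before importing vLLM; this sets the flag the
# patched model code reads via envs.VLLM_USE_V1.
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM, SamplingParams

# Any FalconH1 checkpoint should do; this name is an assumption.
llm = LLM(model="tiiuae/Falcon-H1-1B-Base")
outputs = llm.generate(["The quick brown fox"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)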

vllm/model_executor/models/falcon_h1.py

@@ -8,6 +8,7 @@ import torch
 from torch import nn
 from transformers import FalconH1Config
 
+from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -33,8 +34,7 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
-                         SupportsV0Only)
+from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
 from .utils import (PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
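Dropping SupportsV0Only from this import (and from the class bases later in the diff) is what re-enables V1: the engine treats the marker interface as a signal to force the V0 fallback. A rough, self-contained sketch of that kind of check; the must_use_v0 helper and toy class are hypothetical, the real check lives in vLLM's engine setup:

def must_use_v0(model_cls) -> bool:
    # Hypothetical check mirroring the marker interface: classes that
    # inherit SupportsV0Only expose supports_v0_only = True.
    return bool(getattr(model_cls, "supports_v0_only", False))

class OldStyleModel:
    supports_v0_only = True  # what the SupportsV0Only base provides

print(must_use_v0(OldStyleModel))  # True -> pinned to the V0 engine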
@@ -85,6 +85,7 @@ class FalconH1SSMDecoderLayer(nn.Module):
         config: FalconH1Config,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -107,6 +108,8 @@ class FalconH1SSMDecoderLayer(nn.Module):
             activation=config.hidden_act,
             quant_config=quant_config,
             use_rms_norm=config.mamba_rms_norm,
+            prefix=f"{prefix}.mixer",
+            chunk_size=config.mamba_chunk_size,
         )
         # n_groups is overridden later by `MambaMixer2`
         self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state
@@ -316,6 +319,7 @@ class FalconH1ParallelHybrid(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
# Instantiate the attention branch
self.self_attn = FalconH1AttentionDecoderLayer(
config=config,
@@ -323,11 +327,18 @@ class FalconH1ParallelHybrid(nn.Module):
             quant_config=quant_config,
             prefix=prefix,
         )
+
+        # In V1 all attention/ssm layers must have
+        # different index in prefix
+        ssm_layer_idx = config.num_hidden_layers + layer_idx
+        ssm_prefix = prefix.split(".")[0] + f".{ssm_layer_idx}"
+
         # Instantiate the SSM branch
         self.mamba = FalconH1SSMDecoderLayer(
             config=config,
             cache_config=cache_config,
             quant_config=quant_config,
+            prefix=ssm_prefix,
         )
         self.ssm_out_multiplier = config.ssm_out_multiplier
         self.ssm_in_multiplier = config.ssm_in_multiplier
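Worked example of the prefix arithmetic added above: V1 keys each cache-holding layer by the numeric index embedded in its prefix, so the SSM branch is shifted past all attention indices to avoid collisions. The layer names and model size are illustrative:

num_hidden_layers = 4  # illustrative model size

for layer_idx in range(num_hidden_layers):
    prefix = f"model.layers.{layer_idx}"  # attention branch prefix
    ssm_layer_idx = num_hidden_layers + layer_idx
    ssm_prefix = prefix.split(".")[0] + f".{ssm_layer_idx}"
    print(f"{prefix} -> {ssm_prefix}")

# model.layers.0 -> model.4
# model.layers.1 -> model.5
# model.layers.2 -> model.6
# model.layers.3 -> model.7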
@@ -452,10 +463,16 @@ class FalconH1Model(nn.Module):
         # proper continuous batching computation including
         # chunked prefill
         attn_metadata = get_forward_context().attn_metadata
-        mamba2_metadata = prepare_mamba2_metadata(
-            chunk_size=self.config.mamba_chunk_size,
-            attn_metadata=attn_metadata,
-        )
+
+        if not envs.VLLM_USE_V1:
+            mamba2_metadata = prepare_mamba2_metadata(
+                chunk_size=self.config.mamba_chunk_size,
+                attn_metadata=attn_metadata,
+            )
+        else:
+            # v1 get mamba2_metadata from forward_context
+            mamba2_metadata = None
+
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds * self.embedding_multiplier
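The same V0/V1 split viewed in isolation: V0 builds Mamba2 metadata once per forward pass from the attention metadata, while on V1 the mixer layers read it from the forward context, so the model passes None through. A condensed sketch; the import path for prepare_mamba2_metadata is an assumption, not taken from this diff:

from vllm import envs
from vllm.model_executor.layers.mamba.mamba2_metadata import (
    prepare_mamba2_metadata)  # import path assumed

def build_mamba2_metadata(config, attn_metadata):
    if not envs.VLLM_USE_V1:
        # V0: compute chunked-prefill metadata here, once per forward.
        return prepare_mamba2_metadata(
            chunk_size=config.mamba_chunk_size,
            attn_metadata=attn_metadata,
        )
    # V1: mixer layers pull their metadata from the forward context.
    return None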
@@ -468,7 +485,9 @@ class FalconH1Model(nn.Module):
 
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
-            layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
+            layer_mamba_cache_params = None
+            if mamba_cache_params:
+                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
             hidden_states = layer(
                 positions=positions,
                 hidden_states=hidden_states,
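Because mamba_cache_params stays None under V1, the per-layer lookup above must tolerate the missing cache; each layer then falls back to the state the V1 engine manages for it. A tiny stand-alone illustration with a stand-in cache class (hypothetical, only mimicking at_layer_idx):

from dataclasses import dataclass
from typing import Optional

@dataclass
class StubCacheParams:
    # Stand-in for MambaCacheParams; only at_layer_idx is mimicked.
    layer_idx: Optional[int] = None

    def at_layer_idx(self, i: int) -> "StubCacheParams":
        return StubCacheParams(layer_idx=i)

def params_for_layer(cache_params, i):
    # Mirrors the guarded lookup: None (the V1 case) propagates unchanged.
    return cache_params.at_layer_idx(i) if cache_params else None

print(params_for_layer(None, 3))               # None -> V1 path
print(params_for_layer(StubCacheParams(), 3))  # StubCacheParams(layer_idx=3)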
@@ -484,7 +503,7 @@ class FalconH1Model(nn.Module):
 class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
-                          IsHybrid, SupportsV0Only):
+                          IsHybrid):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -558,15 +577,19 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs,
     ):
-        if self.mamba_cache is None:
-            self.mamba_cache = MambaCacheManager(
-                self.vllm_config,
-                self.lm_head.weight.dtype
-                if hasattr(self.lm_head, 'weight') else torch.bfloat16,
-                self.config.num_hidden_layers,
-                *self._get_mamba_cache_shape(),
-            )
-        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+
+        mamba_cache_params = None
+        if not envs.VLLM_USE_V1:
+            if self.mamba_cache is None:
+                self.mamba_cache = MambaCacheManager(
+                    self.vllm_config,
+                    self.lm_head.weight.dtype if hasattr(
+                        self.lm_head, 'weight') else torch.bfloat16,
+                    self.config.num_hidden_layers,
+                    *self._get_mamba_cache_shape(),
+                )
+            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+
         hidden_states = self.model(
             input_ids,
             positions,
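The forward path above now allocates the MambaCacheManager lazily and only on V0; under V1 the engine core owns the SSM state alongside the KV cache, so the model forwards None. One way to read the pattern as a helper (a refactor sketch under the diff's names, not how the commit structures it):

import torch

from vllm import envs
from vllm.model_executor.models.mamba_cache import MambaCacheManager

def maybe_get_mamba_cache_params(model, **kwargs):
    """V0-only lazy cache setup; returns None on V1 (engine-managed)."""
    if envs.VLLM_USE_V1:
        return None
    if model.mamba_cache is None:
        # Lazy V0 init; dtype falls back to bfloat16 when lm_head has
        # no weight attribute.
        model.mamba_cache = MambaCacheManager(
            model.vllm_config,
            model.lm_head.weight.dtype
            if hasattr(model.lm_head, 'weight') else torch.bfloat16,
            model.config.num_hidden_layers,
            *model._get_mamba_cache_shape(),
        )
    return model.mamba_cache.current_run_tensors(**kwargs)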