Enable V1 for Hybrid SSM/Attention Models (#20016)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Stanislaw Wozniak <stw@zurich.ibm.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Commit: 2f35a022e6 (parent ffe00ef77a)
Author: Thomas Parnell
Date: 2025-07-04 19:46:53 +02:00
Committed by: GitHub
14 changed files with 399 additions and 134 deletions
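The recurring idea in this commit: V0-only code paths are gated behind envs.VLLM_USE_V1, while the V1 engine picks up Mamba metadata and state from its managed forward context. A minimal standalone sketch of that gating pattern (stub logic only, not vLLM internals):

import os

# Mirrors vLLM's envs.VLLM_USE_V1 flag; everything else here is a stub.
VLLM_USE_V1 = os.environ.get("VLLM_USE_V1", "1") == "1"

def get_mamba2_metadata(attn_metadata, chunk_size):
    if not VLLM_USE_V1:
        # V0: build chunked-scan metadata eagerly on every forward pass.
        return {"chunk_size": chunk_size, "attn_metadata": attn_metadata}
    # V1: metadata travels on the per-step forward context instead,
    # so the model-level helper hands back None.
    return None

print(get_mamba2_metadata(attn_metadata=None, chunk_size=256))

The same branch shape appears in three of the hunks below.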

vllm/model_executor/models/zamba2.py

@@ -15,6 +15,7 @@ import torch
 from torch import nn
 from transformers import Zamba2Config

+from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -41,7 +42,7 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

-from .interfaces import HasInnerState, IsHybrid, SupportsV0Only
+from .interfaces import HasInnerState, IsHybrid
 from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
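Dropping SupportsV0Only here (and from the class bases near the end of the diff) is what actually opts the model into V1: the marker interface exists so that models without V1 support fall back to the V0 engine. A hedged sketch of how a marker class can gate engine choice; pick_engine is hypothetical, not vLLM's real selection logic:

# Illustrative only: a marker base class gating engine selection.
class SupportsV0Only:
    """Marker: model code has not been validated on the V1 engine."""

class LegacyModel(SupportsV0Only):
    pass

class Zamba2Like:  # after this commit, no V0-only marker
    pass

def pick_engine(model_cls: type) -> str:
    # Hypothetical check standing in for vLLM's model-interface inspection.
    return "v0" if issubclass(model_cls, SupportsV0Only) else "v1"

assert pick_engine(LegacyModel) == "v0"
assert pick_engine(Zamba2Like) == "v1"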
@@ -58,6 +59,7 @@ class Zamba2LoRA(nn.Module):
         rank: int,
         output_dim: Union[int, list[int]],
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         """Initialize the attention layer.
@@ -283,6 +285,7 @@ class Zamba2MLP(nn.Module):
         bare_block_idx: int,
         num_hybrid_layers: dict[int, int],
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         """Initialize the MLP layer.
@@ -471,11 +474,10 @@ class Zamba2MambaDecoderLayer(nn.Module):
     computation depending on configuration.
     """

-    def __init__(
-        self,
-        config: Zamba2Config,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
+    def __init__(self,
+                 config: Zamba2Config,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
         """Initialize the Mamba decoder layer.

         Args:
@@ -486,20 +488,21 @@ class Zamba2MambaDecoderLayer(nn.Module):
         # Initialize Mamba mixer with expanded intermediate size
         intermediate_size = config.mamba_expand * config.hidden_size
-        self.mamba = MambaMixer2(
-            hidden_size=config.hidden_size,
-            ssm_state_size=config.mamba_d_state,
-            conv_kernel_size=config.mamba_d_conv,
-            intermediate_size=intermediate_size,
-            use_conv_bias=config.use_conv_bias,
-            use_bias=config.add_bias_linear,
-            n_groups=config.mamba_ngroups,
-            num_heads=config.n_mamba_heads,
-            head_dim=intermediate_size // config.n_mamba_heads,
-            rms_norm_eps=config.rms_norm_eps,
-            activation="silu",
-            quant_config=quant_config,
-        )
+        self.mamba = MambaMixer2(hidden_size=config.hidden_size,
+                                 ssm_state_size=config.mamba_d_state,
+                                 conv_kernel_size=config.mamba_d_conv,
+                                 intermediate_size=intermediate_size,
+                                 use_conv_bias=config.use_conv_bias,
+                                 use_bias=config.add_bias_linear,
+                                 n_groups=config.mamba_ngroups,
+                                 num_heads=config.n_mamba_heads,
+                                 head_dim=intermediate_size //
+                                 config.n_mamba_heads,
+                                 rms_norm_eps=config.rms_norm_eps,
+                                 activation="silu",
+                                 quant_config=quant_config,
+                                 prefix=f"{prefix}.mixer",
+                                 chunk_size=config.chunk_size)

         # Input normalization
         self.input_layernorm = RMSNorm(config.hidden_size,
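The substantive additions to the MambaMixer2 call are prefix=f"{prefix}.mixer" and chunk_size: under V1, each stateful layer needs a unique dotted name so the engine can associate it with its own slice of managed state. A toy illustration of prefix-keyed state; StateRegistry is invented for this sketch, and vLLM does the equivalent bookkeeping through its forward context and cache coordinator:

import torch

class StateRegistry:
    """Toy stand-in for engine-managed per-layer state."""

    def __init__(self) -> None:
        self._state: dict[str, torch.Tensor] = {}

    def register(self, prefix: str, shape: tuple[int, ...]) -> None:
        # One buffer per unique layer name; duplicate prefixes would
        # silently share state, which unique prefixes prevent.
        assert prefix not in self._state, f"duplicate layer prefix {prefix!r}"
        self._state[prefix] = torch.zeros(shape)

    def get(self, prefix: str) -> torch.Tensor:
        return self._state[prefix]

registry = StateRegistry()
registry.register("3.mixer", (8, 64, 16))  # conv/ssm state for one layer
registry.register("4.mixer", (8, 64, 16))  # sibling layer, distinct key
assert registry.get("3.mixer") is not registry.get("4.mixer")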
@@ -573,6 +576,7 @@ class Zamba2HybridLayer(nn.Module):
         config: Zamba2Config,
         block_idx: int,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         """Initialize the hybrid layer.
@@ -589,7 +593,8 @@ class Zamba2HybridLayer(nn.Module):
                                        bias=False,
                                        quant_config=quant_config)
         self.mamba_decoder = Zamba2MambaDecoderLayer(config,
-                                                     quant_config=quant_config)
+                                                     quant_config=quant_config,
+                                                     prefix=prefix)

     def forward(
         self,
@@ -699,14 +704,23 @@ class Zamba2Model(nn.Module):
         # Initialize layers according to block type configuration
         layers = []
         for layer_idx, layer_type in enumerate(config.layers_block_type):
+            # tdoublep: avoid layers getting same index
+            # somewhat hacky but correct (I think)
+            prefix = str(len(layer2block_map) + layer_idx)
             if layer_type == "hybrid":
                 block = next(blocks)
                 block_idx = layer2block_map[layer_idx]
                 layers.append(
-                    Zamba2HybridLayer(block, config, block_idx, quant_config))
+                    Zamba2HybridLayer(block,
+                                      config,
+                                      block_idx,
+                                      quant_config,
+                                      prefix=prefix))
             else:
                 layers.append(
-                    Zamba2MambaDecoderLayer(config, quant_config=quant_config))
+                    Zamba2MambaDecoderLayer(config,
+                                            quant_config=quant_config,
+                                            prefix=prefix))
         self.layers = nn.ModuleList(layers)

         # Final layer normalization
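The prefix = str(len(layer2block_map) + layer_idx) line rewards a second look: layer_idx alone already distinguishes decoder layers, so on one reading the len(layer2block_map) offset keeps these names clear of the index range the shared hybrid blocks occupy. A worked example, assuming layer2block_map holds one entry per "hybrid" entry in layers_block_type:

layers_block_type = ["mamba", "hybrid", "mamba", "hybrid", "mamba"]

# One entry per hybrid layer: layer index -> shared block index.
hybrid_layer_ids = [i for i, t in enumerate(layers_block_type) if t == "hybrid"]
layer2block_map = {layer_idx: block_idx
                   for block_idx, layer_idx in enumerate(hybrid_layer_ids)}
assert layer2block_map == {1: 0, 3: 1}

prefixes = [str(len(layer2block_map) + layer_idx)
            for layer_idx in range(len(layers_block_type))]
print(prefixes)  # ['2', '3', '4', '5', '6'] -- unique, offset past 0..1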
@@ -751,19 +765,30 @@ class Zamba2Model(nn.Module):
         attn_metadata = get_forward_context().attn_metadata
-        mamba2_metadata = prepare_mamba2_metadata(
-            chunk_size=self.config.chunk_size,
-            attn_metadata=attn_metadata,
-        )
+        if not envs.VLLM_USE_V1:
+            mamba2_metadata = prepare_mamba2_metadata(
+                chunk_size=self.config.chunk_size,
+                attn_metadata=attn_metadata,
+            )
+        else:
+            # v1 get mamba2_metadata from forward_context
+            mamba2_metadata = None

         # Process through layers
         original_hidden_states = torch.clone(hidden_states)
         for layer_idx, layer in enumerate(self.layers):
+            layer_mamba_cache_params = None
+            if (isinstance(layer, (Zamba2HybridLayer, Zamba2MambaDecoderLayer))
+                    and mamba_cache_params):
+                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
+                    layer_idx)
+
             layer_outputs = layer(
                 hidden_states,
                 original_hidden_states=original_hidden_states,
                 positions=positions,
-                mamba_cache_params=mamba_cache_params.at_layer_idx(layer_idx),
+                mamba_cache_params=layer_mamba_cache_params,
                 mamba2_metadata=mamba2_metadata,
             )
             hidden_states = layer_outputs
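Two V1 accommodations land in this hunk: prepare_mamba2_metadata now runs only on V0 (under V1 the layers read equivalent metadata off the forward context), and at_layer_idx is only called when a V0 cache manager actually produced mamba_cache_params, which stays None under V1. A compressed sketch of the per-layer branch, with a stub MambaCacheParams standing in for vLLM's class:

from dataclasses import dataclass
from typing import Optional

@dataclass
class MambaCacheParams:
    """Stub: a view over the V0 Mamba cache."""
    layer_idx: int = -1

    def at_layer_idx(self, i: int) -> "MambaCacheParams":
        return MambaCacheParams(layer_idx=i)

def cache_view(params: Optional[MambaCacheParams], layer_idx: int):
    if params is None:
        return None  # V1: the engine owns the state; nothing to slice here
    return params.at_layer_idx(layer_idx)  # V0: per-layer cache view

print(cache_view(None, 0))                 # None  (V1 path)
print(cache_view(MambaCacheParams(), 2))   # MambaCacheParams(layer_idx=2)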
@@ -803,7 +828,7 @@ class Zamba2Model(nn.Module):
         return loaded_params


-class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
+class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
     """Zamba2 model with causal language modeling head.

     This class wraps the core Zamba2 model and adds:
@@ -897,14 +922,16 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
             Output hidden states
         """
         # Initialize Mamba cache if needed
-        if self.mamba_cache is None:
-            num_mamba_layers = self.config.num_hidden_layers
-            self.mamba_cache = MambaCacheManager(
-                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
-                *self._get_mamba_cache_shape())
+        mamba_cache_params = None
+        if not envs.VLLM_USE_V1:
+            if self.mamba_cache is None:
+                num_mamba_layers = self.config.num_hidden_layers
+                self.mamba_cache = MambaCacheManager(
+                    self.vllm_config, self.lm_head.weight.dtype,
+                    num_mamba_layers, *self._get_mamba_cache_shape())

-        # Get cache parameters for current run
-        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+            # Get cache parameters for current run
+            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

         # Forward pass through model
         hidden_states = self.model(