Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -2,11 +2,12 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""PyTorch Zamba2 model implementation for vLLM.
|
||||
|
||||
This module implements the Zamba2 architecture from
|
||||
https://arxiv.org/abs/2411.15242, which combines Mamba and Transformer
|
||||
architectures in a hybrid model optimized for efficient sequence modeling. The
|
||||
This module implements the Zamba2 architecture from
|
||||
https://arxiv.org/abs/2411.15242, which combines Mamba and Transformer
|
||||
architectures in a hybrid model optimized for efficient sequence modeling. The
|
||||
model alternates between state space model layers and attention-based layers.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from itertools import cycle
|
||||
from typing import Any, Optional, Union
|
||||
@@ -21,19 +22,26 @@ from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import GeluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||
MambaStateDtypeCalculator,
|
||||
MambaStateShapeCalculator,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
DEFAULT_VOCAB_PADDING_SIZE,
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
@@ -43,7 +51,7 @@ from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
||||
|
||||
class Zamba2LoRA(nn.Module):
|
||||
"""LoRA layer for the Zamba2 model.
|
||||
|
||||
|
||||
Implements a LoRA layer that is used in shared attention and gated MLP
|
||||
blocks.
|
||||
"""
|
||||
@@ -57,7 +65,7 @@ class Zamba2LoRA(nn.Module):
|
||||
prefix: str = "",
|
||||
):
|
||||
"""Initialize the attention layer.
|
||||
|
||||
|
||||
Args:
|
||||
input_dim: input dimension
|
||||
rank: LoRA rank
|
||||
@@ -66,20 +74,15 @@ class Zamba2LoRA(nn.Module):
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.A = ColumnParallelLinear(input_dim,
|
||||
rank,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
gather_output=True)
|
||||
self.A = ColumnParallelLinear(
|
||||
input_dim, rank, bias=False, quant_config=quant_config, gather_output=True
|
||||
)
|
||||
|
||||
if isinstance(output_dim, list):
|
||||
B_class = MergedColumnParallelLinear
|
||||
else:
|
||||
B_class = ColumnParallelLinear
|
||||
self.B = B_class(rank,
|
||||
output_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config)
|
||||
self.B = B_class(rank, output_dim, bias=False, quant_config=quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -92,8 +95,8 @@ class Zamba2LoRA(nn.Module):
|
||||
|
||||
class Zamba2Attention(nn.Module):
|
||||
"""Multi-head attention mechanism for the Zamba2 model.
|
||||
|
||||
Implements attention with parallel computation, QKV projections, optional
|
||||
|
||||
Implements attention with parallel computation, QKV projections, optional
|
||||
adapters and rotary position embeddings. The attention is computed across
|
||||
distributed blocks for efficient processing.
|
||||
"""
|
||||
@@ -108,7 +111,7 @@ class Zamba2Attention(nn.Module):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
"""Initialize the attention layer.
|
||||
|
||||
|
||||
Args:
|
||||
config: The Zamba2 model configuration
|
||||
bare_block_idx: Index of the bare attention block
|
||||
@@ -129,15 +132,17 @@ class Zamba2Attention(nn.Module):
|
||||
self.num_attention_heads = config.num_attention_heads // tp_size
|
||||
self.attention_head_dim = config.attention_head_dim
|
||||
self.qkv_size = self.attention_hidden_size // tp_size
|
||||
self.scale = (self.attention_head_dim / 2)**-0.5
|
||||
self.scale = (self.attention_head_dim / 2) ** -0.5
|
||||
|
||||
if (self.attention_head_dim *
|
||||
self.total_num_attention_heads) != self.attention_hidden_size:
|
||||
if (
|
||||
self.attention_head_dim * self.total_num_attention_heads
|
||||
) != self.attention_hidden_size:
|
||||
raise ValueError(
|
||||
f"attention_hidden_size must be divisible by"
|
||||
f" num_attention_heads"
|
||||
f" (got `attention_hidden_size`: {self.attention_hidden_size}"
|
||||
f" and `num_heads`: {self.num_attention_heads}).")
|
||||
f" and `num_heads`: {self.num_attention_heads})."
|
||||
)
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
self.attention_hidden_size,
|
||||
@@ -146,10 +151,12 @@ class Zamba2Attention(nn.Module):
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(self.attention_hidden_size,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.attention_hidden_size,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
# Even though in Zamba2 weights are shared between attention layers, KV
|
||||
# cache is unique for every attention layer. Hence, we need to define
|
||||
@@ -158,8 +165,11 @@ class Zamba2Attention(nn.Module):
|
||||
|
||||
# Initialize attention blocks with proper indexing
|
||||
self.dpa_list = nn.ModuleList([])
|
||||
j = bare_block_idx * (self.num_hybrid_layers + config.num_mem_blocks -
|
||||
1) // config.num_mem_blocks
|
||||
j = (
|
||||
bare_block_idx
|
||||
* (self.num_hybrid_layers + config.num_mem_blocks - 1)
|
||||
// config.num_mem_blocks
|
||||
)
|
||||
for block_idx in range(self.num_hybrid_layers):
|
||||
if block_idx % config.num_mem_blocks == bare_block_idx:
|
||||
dpa = Attention(
|
||||
@@ -226,18 +236,17 @@ class Zamba2Attention(nn.Module):
|
||||
position_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass through the attention layer.
|
||||
|
||||
|
||||
Args:
|
||||
hidden_states: Input tensor [batch_size, seq_len, hidden_size]
|
||||
position_ids: Position IDs for positional embeddings
|
||||
block_idx: Current shared transformer block index
|
||||
|
||||
|
||||
Returns:
|
||||
Output tensor [batch_size, seq_len, hidden_size]
|
||||
"""
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
query_states, key_states, value_states = qkv.split([self.qkv_size] * 3,
|
||||
dim=-1)
|
||||
query_states, key_states, value_states = qkv.split([self.qkv_size] * 3, dim=-1)
|
||||
|
||||
if self.config.use_shared_attention_adapter:
|
||||
# Apply adapter transformations to Q, K, V if enabled
|
||||
@@ -257,9 +266,9 @@ class Zamba2Attention(nn.Module):
|
||||
value_states = value_states + v_lora_output
|
||||
|
||||
if self.config.use_mem_rope:
|
||||
query_states, key_states = self.rotary_emb(position_ids,
|
||||
query_states,
|
||||
key_states)
|
||||
query_states, key_states = self.rotary_emb(
|
||||
position_ids, query_states, key_states
|
||||
)
|
||||
|
||||
y = self.dpa_list[block_idx](query_states, key_states, value_states)
|
||||
y, _ = self.o_proj(y)
|
||||
@@ -268,9 +277,9 @@ class Zamba2Attention(nn.Module):
|
||||
|
||||
class Zamba2MLP(nn.Module):
|
||||
"""Feed-forward MLP layer for the Zamba2 model.
|
||||
|
||||
Implements a gated feed-forward network that projects inputs to a larger
|
||||
intermediate size, applies GELU activation with gating, then projects back
|
||||
|
||||
Implements a gated feed-forward network that projects inputs to a larger
|
||||
intermediate size, applies GELU activation with gating, then projects back
|
||||
to the original size. Includes optional adapter layers for model adaptation.
|
||||
"""
|
||||
|
||||
@@ -283,7 +292,7 @@ class Zamba2MLP(nn.Module):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
"""Initialize the MLP layer.
|
||||
|
||||
|
||||
Args:
|
||||
config: The Zamba2 model configuration
|
||||
bare_block_idx: Index of the bare block in the model
|
||||
@@ -302,17 +311,22 @@ class Zamba2MLP(nn.Module):
|
||||
self.hidden_size,
|
||||
2 * [self.intermediate_size], # 2x for gate and input projections
|
||||
bias=self.config.add_bias_linear,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.down_proj = RowParallelLinear(self.intermediate_size,
|
||||
self.hidden_size,
|
||||
bias=self.config.add_bias_linear,
|
||||
quant_config=quant_config)
|
||||
self.down_proj = RowParallelLinear(
|
||||
self.intermediate_size,
|
||||
self.hidden_size,
|
||||
bias=self.config.add_bias_linear,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
# Only allow GELU activations
|
||||
if config.hidden_act != "gelu":
|
||||
raise ValueError(f"Only GELU activation is supported "
|
||||
f"(got `hidden_act`: {config.hidden_act})")
|
||||
raise ValueError(
|
||||
f"Only GELU activation is supported "
|
||||
f"(got `hidden_act`: {config.hidden_act})"
|
||||
)
|
||||
self.act_fn = GeluAndMul()
|
||||
|
||||
# Initialize adapter layers
|
||||
@@ -329,14 +343,13 @@ class Zamba2MLP(nn.Module):
|
||||
gate_up_proj_adapter = nn.Identity()
|
||||
self.gate_up_proj_adapter_list.append(gate_up_proj_adapter)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
block_idx: int) -> torch.Tensor:
|
||||
def forward(self, hidden_states: torch.Tensor, block_idx: int) -> torch.Tensor:
|
||||
"""Forward pass through the MLP layer.
|
||||
|
||||
|
||||
Args:
|
||||
hidden_states: Input tensor [batch_size, seq_len, hidden_size]
|
||||
block_idx: Current shared transformer block index
|
||||
|
||||
|
||||
Returns:
|
||||
Output tensor [batch_size, seq_len, hidden_size] after applying
|
||||
gated feed-forward transformation
|
||||
@@ -360,7 +373,7 @@ class Zamba2MLP(nn.Module):
|
||||
|
||||
class Zamba2AttentionDecoderLayer(nn.Module):
|
||||
"""Single decoder layer combining attention and feed-forward networks.
|
||||
|
||||
|
||||
This layer implements a standard transformer block with:
|
||||
- Input layer normalization
|
||||
- Multi-head self-attention
|
||||
@@ -378,7 +391,7 @@ class Zamba2AttentionDecoderLayer(nn.Module):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
"""Initialize the decoder layer.
|
||||
|
||||
|
||||
Args:
|
||||
config: The Zamba2 model configuration
|
||||
bare_block_idx: Index of the bare block
|
||||
@@ -409,11 +422,9 @@ class Zamba2AttentionDecoderLayer(nn.Module):
|
||||
|
||||
# Initialize layer normalizations
|
||||
# Input normalization operates on concatenated states
|
||||
self.input_layernorm = RMSNorm(2 * config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.input_layernorm = RMSNorm(2 * config.hidden_size, eps=config.rms_norm_eps)
|
||||
# Pre-FF normalization operates on attention output
|
||||
self.pre_ff_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -423,14 +434,14 @@ class Zamba2AttentionDecoderLayer(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass through the decoder layer.
|
||||
|
||||
|
||||
Args:
|
||||
hidden_states: Input tensor from previous layer
|
||||
original_hidden_states: Original input tensor for residual
|
||||
original_hidden_states: Original input tensor for residual
|
||||
connection
|
||||
block_idx: Current shared transformer block index
|
||||
positions: IDs for positional embeddings
|
||||
|
||||
|
||||
Returns:
|
||||
Transformed hidden states after attention and feed-forward
|
||||
"""
|
||||
@@ -440,7 +451,8 @@ class Zamba2AttentionDecoderLayer(nn.Module):
|
||||
# The concatenated tensor is then used as input of the pre-attention
|
||||
# RMSNorm (see fig. 2 in https://arxiv.org/pdf/2405.16712).
|
||||
hidden_states = torch.concatenate(
|
||||
[hidden_states, original_hidden_states], dim=-1)
|
||||
[hidden_states, original_hidden_states], dim=-1
|
||||
)
|
||||
|
||||
# Layer norm before attention
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
@@ -463,20 +475,22 @@ class Zamba2AttentionDecoderLayer(nn.Module):
|
||||
|
||||
class Zamba2MambaDecoderLayer(nn.Module):
|
||||
"""Single Mamba decoder layer with normalization.
|
||||
|
||||
This implements a Mamba block. It includes input normalization
|
||||
and can process sequences using either chunked or full
|
||||
|
||||
This implements a Mamba block. It includes input normalization
|
||||
and can process sequences using either chunked or full
|
||||
computation depending on configuration.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config: Zamba2Config,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "") -> None:
|
||||
def __init__(
|
||||
self,
|
||||
config: Zamba2Config,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
"""Initialize the Mamba decoder layer.
|
||||
|
||||
|
||||
Args:
|
||||
config: The Zamba2 model configuration
|
||||
quant_config: Configuration for model quantization
|
||||
@@ -485,26 +499,26 @@ class Zamba2MambaDecoderLayer(nn.Module):
|
||||
|
||||
# Initialize Mamba mixer with expanded intermediate size
|
||||
intermediate_size = config.mamba_expand * config.hidden_size
|
||||
self.mamba = MambaMixer2(hidden_size=config.hidden_size,
|
||||
ssm_state_size=config.mamba_d_state,
|
||||
conv_kernel_size=config.mamba_d_conv,
|
||||
intermediate_size=intermediate_size,
|
||||
use_conv_bias=config.use_conv_bias,
|
||||
use_bias=config.add_bias_linear,
|
||||
n_groups=config.mamba_ngroups,
|
||||
num_heads=config.n_mamba_heads,
|
||||
head_dim=intermediate_size //
|
||||
config.n_mamba_heads,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
activation="silu",
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mixer")
|
||||
self.mamba = MambaMixer2(
|
||||
hidden_size=config.hidden_size,
|
||||
ssm_state_size=config.mamba_d_state,
|
||||
conv_kernel_size=config.mamba_d_conv,
|
||||
intermediate_size=intermediate_size,
|
||||
use_conv_bias=config.use_conv_bias,
|
||||
use_bias=config.add_bias_linear,
|
||||
n_groups=config.mamba_ngroups,
|
||||
num_heads=config.n_mamba_heads,
|
||||
head_dim=intermediate_size // config.n_mamba_heads,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
activation="silu",
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mixer",
|
||||
)
|
||||
|
||||
# Input normalization
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -514,14 +528,14 @@ class Zamba2MambaDecoderLayer(nn.Module):
|
||||
original_hidden_states: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass through the Mamba decoder layer.
|
||||
|
||||
|
||||
Args:
|
||||
hidden_states: Input tensor [batch_size, seq_len, hidden_size]
|
||||
transformer_hidden_states: Optional output from transformer path
|
||||
Added to input if provided (used in hybrid architecture)
|
||||
positions: Optional position IDs (unused in Mamba)
|
||||
original_hidden_states: Optional original inputs (unused in Mamba)
|
||||
|
||||
|
||||
Returns:
|
||||
Transformed hidden states with residual connection applied
|
||||
"""
|
||||
@@ -555,7 +569,7 @@ class Zamba2MambaDecoderLayer(nn.Module):
|
||||
|
||||
class Zamba2HybridLayer(nn.Module):
|
||||
"""Hybrid layer combining Transformer and Mamba architectures.
|
||||
|
||||
|
||||
This layer implements the hybrid architecture described in the Zamba paper,
|
||||
where a shared transformer pathway processes input in parallel with a Mamba
|
||||
pathway. The transformer output is projected and added to the Mamba input
|
||||
@@ -573,22 +587,26 @@ class Zamba2HybridLayer(nn.Module):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
"""Initialize the hybrid layer.
|
||||
|
||||
|
||||
Args:
|
||||
shared_transformer: Transformer decoder layer for attention pathway
|
||||
"""
|
||||
super().__init__()
|
||||
self.block_idx = block_idx
|
||||
self.shared_transformer = shared_transformer
|
||||
self.linear = ReplicatedLinear(config.hidden_size,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config)
|
||||
self.mamba_decoder = Zamba2MambaDecoderLayer(config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
self.linear = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.mamba_decoder = Zamba2MambaDecoderLayer(
|
||||
config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -597,19 +615,19 @@ class Zamba2HybridLayer(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass through the hybrid layer.
|
||||
|
||||
|
||||
Processes input through parallel transformer and Mamba paths:
|
||||
1. Transformer path processes input with attention
|
||||
2. Transformer output is projected to match hidden size
|
||||
3. Projected output is added to Mamba path input
|
||||
4. Final output combines both paths' representations
|
||||
|
||||
|
||||
Args:
|
||||
hidden_states: Input tensor [batch_size, seq_len, hidden_size]
|
||||
original_hidden_states: Original input for transformer residual
|
||||
original_hidden_states: Original input for transformer residual
|
||||
connection
|
||||
positions: Position IDs for positional embeddings
|
||||
|
||||
|
||||
Returns:
|
||||
Output tensor combining transformer and Mamba representations
|
||||
"""
|
||||
@@ -636,16 +654,16 @@ class Zamba2HybridLayer(nn.Module):
|
||||
@support_torch_compile
|
||||
class Zamba2Model(nn.Module):
|
||||
"""Core Zamba2 model combining transformer and Mamba architectures.
|
||||
|
||||
The model processes input through a sequence of hybrid and Mamba-only
|
||||
|
||||
The model processes input through a sequence of hybrid and Mamba-only
|
||||
layers, using token embeddings and final layer normalization.
|
||||
"""
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
"""Initialize the Zamba2 model.
|
||||
|
||||
|
||||
Args:
|
||||
vllm_config: Configuration object containing model, cache,
|
||||
vllm_config: Configuration object containing model, cache,
|
||||
quantization and LoRA settings
|
||||
prefix: Optional prefix for parameter names in state dict
|
||||
"""
|
||||
@@ -660,8 +678,11 @@ class Zamba2Model(nn.Module):
|
||||
assert not is_lora_enabled
|
||||
|
||||
self.config = config
|
||||
lora_vocab = ((lora_config.lora_extra_vocab_size *
|
||||
(lora_config.max_loras or 1)) if lora_config else 0)
|
||||
lora_vocab = (
|
||||
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
|
||||
if lora_config
|
||||
else 0
|
||||
)
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
self.org_vocab_size = config.vocab_size
|
||||
|
||||
@@ -679,15 +700,19 @@ class Zamba2Model(nn.Module):
|
||||
}
|
||||
|
||||
# Create cyclic iterator of transformer blocks
|
||||
blocks = cycle([
|
||||
Zamba2AttentionDecoderLayer(config,
|
||||
bare_block_idx=idx,
|
||||
num_hybrid_layers=len(layer2block_map),
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}")
|
||||
for idx in range(config.num_mem_blocks)
|
||||
])
|
||||
blocks = cycle(
|
||||
[
|
||||
Zamba2AttentionDecoderLayer(
|
||||
config,
|
||||
bare_block_idx=idx,
|
||||
num_hybrid_layers=len(layer2block_map),
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}",
|
||||
)
|
||||
for idx in range(config.num_mem_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
# Initialize layers according to block type configuration
|
||||
layers = []
|
||||
@@ -699,32 +724,37 @@ class Zamba2Model(nn.Module):
|
||||
block = next(blocks)
|
||||
block_idx = layer2block_map[layer_idx]
|
||||
layers.append(
|
||||
Zamba2HybridLayer(block,
|
||||
config,
|
||||
block_idx,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix))
|
||||
Zamba2HybridLayer(
|
||||
block,
|
||||
config,
|
||||
block_idx,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
)
|
||||
)
|
||||
else:
|
||||
layers.append(
|
||||
Zamba2MambaDecoderLayer(config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix))
|
||||
Zamba2MambaDecoderLayer(
|
||||
config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
)
|
||||
)
|
||||
self.layers = nn.ModuleList(layers)
|
||||
|
||||
# Final layer normalization
|
||||
self.final_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert input token IDs to embeddings.
|
||||
|
||||
|
||||
Args:
|
||||
input_ids: Tensor of input token IDs
|
||||
|
||||
|
||||
Returns:
|
||||
Embedded representation of the input tokens
|
||||
"""
|
||||
@@ -737,14 +767,14 @@ class Zamba2Model(nn.Module):
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
"""Forward pass through the model.
|
||||
|
||||
|
||||
Args:
|
||||
input_ids: Input token IDs
|
||||
positions: Position IDs for embeddings
|
||||
inputs_embeds: Optional pre-computed input embeddings
|
||||
|
||||
|
||||
Returns:
|
||||
Either final hidden states or intermediate tensors for pipeline
|
||||
Either final hidden states or intermediate tensors for pipeline
|
||||
parallelism
|
||||
"""
|
||||
# Handle pipeline parallelism for first rank
|
||||
@@ -765,8 +795,7 @@ class Zamba2Model(nn.Module):
|
||||
hidden_states = self.final_layernorm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@@ -780,8 +809,7 @@ class Zamba2Model(nn.Module):
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in chkpt_weight_name:
|
||||
continue
|
||||
chkpt_weight_name = chkpt_weight_name.replace(
|
||||
weight_name, param_name)
|
||||
chkpt_weight_name = chkpt_weight_name.replace(weight_name, param_name)
|
||||
param = params_dict[chkpt_weight_name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
@@ -790,8 +818,7 @@ class Zamba2Model(nn.Module):
|
||||
if chkpt_weight_name not in params_dict:
|
||||
continue
|
||||
param = params_dict[chkpt_weight_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(chkpt_weight_name)
|
||||
return loaded_params
|
||||
@@ -799,26 +826,28 @@ class Zamba2Model(nn.Module):
|
||||
|
||||
class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
"""Zamba2 model with causal language modeling head.
|
||||
|
||||
|
||||
This class wraps the core Zamba2 model and adds:
|
||||
- A language modeling head for next token prediction
|
||||
- Mamba state caching functionality
|
||||
- Support for model parallelism and quantization
|
||||
- Sampling capabilities for text generation
|
||||
"""
|
||||
|
||||
# To ensure correct weight loading and mapping.
|
||||
hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
|
||||
"A_log": "A",
|
||||
"0.weight": "A.weight",
|
||||
"1.weight": "B.weight",
|
||||
})
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={
|
||||
"A_log": "A",
|
||||
"0.weight": "A.weight",
|
||||
"1.weight": "B.weight",
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_mamba_state_dtype_from_config(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
) -> tuple[torch.dtype, torch.dtype]:
|
||||
|
||||
return MambaStateDtypeCalculator.mamba2_state_dtype(
|
||||
vllm_config.model_config.dtype,
|
||||
vllm_config.cache_config.mamba_cache_dtype,
|
||||
@@ -857,14 +886,14 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
"""Initialize the Zamba2 model for causal language modeling.
|
||||
|
||||
|
||||
Args:
|
||||
vllm_config: Configuration containing model, cache, quantization,
|
||||
LoRA and scheduler settings
|
||||
prefix: Optional prefix for parameter names
|
||||
|
||||
|
||||
Raises:
|
||||
AssertionError: If prefix caching is enabled
|
||||
AssertionError: If prefix caching is enabled
|
||||
(not supported by Mamba)
|
||||
"""
|
||||
config = vllm_config.model_config.hf_config
|
||||
@@ -881,8 +910,9 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
|
||||
# Initialize core model
|
||||
self.model = Zamba2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.model = Zamba2Model(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
|
||||
# Initialize language modeling head
|
||||
self.lm_head = ParallelLMHead(
|
||||
@@ -892,15 +922,17 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
if not lora_config
|
||||
else lora_config.lora_vocab_padding_size,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
# Tie weights with input embeddings if using same dimensions
|
||||
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
|
||||
|
||||
# Initialize logits processing and sampling
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
self.logits_processor = LogitsProcessor(
|
||||
self.unpadded_vocab_size, config.vocab_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert input token IDs to embeddings.
|
||||
@@ -911,19 +943,21 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
"""
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: Any) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: Any,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass through the model.
|
||||
|
||||
|
||||
Args:
|
||||
input_ids: Input token IDs
|
||||
positions: Position IDs for embeddings
|
||||
inputs_embeds: Optional pre-computed input embeddings
|
||||
**kwargs: Additional arguments passed to cache manager
|
||||
|
||||
|
||||
Returns:
|
||||
Output hidden states
|
||||
"""
|
||||
@@ -951,7 +985,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
logits = self.logits_processor(self.lm_head, hidden_states)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
Reference in New Issue
Block a user