Fix Plamo 2/3 & LFM2 for Transformers v5 (#38090)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -52,7 +52,7 @@ class Lfm2MLP(nn.Module):
     def __init__(
         self,
         dim: int,
-        ff_dim: int,
+        intermediate_size: int,
         multiple_of: int,
         auto_adjust_ff_dim: bool,
         ffn_dim_multiplier: float | None,
@@ -61,21 +61,23 @@ class Lfm2MLP(nn.Module):
     ):
         super().__init__()
         if auto_adjust_ff_dim:
-            ff_dim = int(2 * ff_dim / 3)
+            intermediate_size = int(2 * intermediate_size / 3)
             # custom dim factor multiplier
             if ffn_dim_multiplier is not None:
-                ff_dim = int(ffn_dim_multiplier * ff_dim)
-            ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
+                intermediate_size = int(ffn_dim_multiplier * intermediate_size)
+            intermediate_size = multiple_of * (
+                (intermediate_size + multiple_of - 1) // multiple_of
+            )
 
         self.w13 = MergedColumnParallelLinear(
             input_size=dim,
-            output_sizes=[ff_dim] * 2,
+            output_sizes=[intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.w13",
         )
         self.w2 = RowParallelLinear(
-            input_size=ff_dim,
+            input_size=intermediate_size,
             output_size=dim,
             bias=False,
             quant_config=quant_config,
@@ -212,7 +214,7 @@ class Lfm2AttentionDecoderLayer(nn.Module):
 
         self.feed_forward = Lfm2MLP(
             dim=config.block_dim,
-            ff_dim=config.block_ff_dim,
+            intermediate_size=config.intermediate_size,
             multiple_of=config.block_multiple_of,
             auto_adjust_ff_dim=config.block_auto_adjust_ff_dim,
             ffn_dim_multiplier=config.block_ffn_dim_multiplier,
@@ -262,7 +264,7 @@ class Lfm2ShortConvDecoderLayer(nn.Module):
 
         self.feed_forward = Lfm2MLP(
             dim=config.block_dim,
-            ff_dim=config.block_ff_dim,
+            intermediate_size=config.intermediate_size,
             multiple_of=config.block_multiple_of,
             auto_adjust_ff_dim=config.block_auto_adjust_ff_dim,
             ffn_dim_multiplier=config.block_ffn_dim_multiplier,
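The four hunks above are the LFM2 changes: the `Lfm2MLP` width argument is renamed from `ff_dim` to `intermediate_size` and is now read from `config.intermediate_size` rather than `config.block_ff_dim`, presumably to match the attribute name exposed by the Transformers v5 `Lfm2Config`; the width-adjustment arithmetic itself is unchanged apart from the rename. A minimal sketch of that arithmetic, reusing the parameter names from the diff (illustrative only, not the vLLM implementation):

# Illustrative sketch of the width adjustment done in Lfm2MLP.__init__.
def adjusted_intermediate_size(
    intermediate_size: int,
    multiple_of: int,
    auto_adjust_ff_dim: bool,
    ffn_dim_multiplier: float | None,
) -> int:
    if auto_adjust_ff_dim:
        # Shrink the nominal width to 2/3, as is common for gated MLPs.
        intermediate_size = int(2 * intermediate_size / 3)
        # Apply the optional custom multiplier from the config.
        if ffn_dim_multiplier is not None:
            intermediate_size = int(ffn_dim_multiplier * intermediate_size)
        # Round up to the nearest multiple of `multiple_of`.
        intermediate_size = multiple_of * (
            (intermediate_size + multiple_of - 1) // multiple_of
        )
    return intermediate_size

# Example: 8192 -> 5461 after the 2/3 shrink -> 5632 after rounding up to 256.
assert adjusted_intermediate_size(8192, 256, True, None) == 5632

The hunks that follow are the Plamo 2 changes.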
@@ -4,6 +4,7 @@
 
 from collections.abc import Iterable
 from itertools import islice
+from typing import TYPE_CHECKING
 
 import torch
 from torch import nn
@@ -71,30 +72,31 @@ from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backend import AttentionMetadata
 from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
 
 
-# Only used for type hinting.
-class Plamo2Config(PretrainedConfig):  # type: ignore
-    model_type: str = "plamo2"
+if TYPE_CHECKING:
 
-    hidden_size: int
-    num_hidden_layers: int
-    rms_norm_eps: float
-    # Attention
-    num_attention_heads: int
-    hidden_size_per_head: int
-    num_key_value_heads: int
-    # Mamba
-    mamba_d_state: int
-    mamba_d_conv: int
-    mamba_num_heads: int
-    mamba_step: int
-    # MLP
-    intermediate_size: int
-    # Tokenizer
-    vocab_size: int
+    class Plamo2Config(PretrainedConfig):  # type: ignore
+        model_type: str = "plamo2"
+
+        hidden_size: int
+        num_hidden_layers: int
+        rms_norm_eps: float
+        # Attention
+        num_attention_heads: int
+        hidden_size_per_head: int
+        num_key_value_heads: int
+        # Mamba
+        mamba_d_state: int
+        mamba_d_conv: int
+        mamba_num_heads: int
+        mamba_step: int
+        # MLP
+        intermediate_size: int
+        # Tokenizer
+        vocab_size: int
 
 
-def is_mamba(config: Plamo2Config, i: int) -> bool:
+def is_mamba(config: "Plamo2Config", i: int) -> bool:
     assert config.mamba_step > 1
 
     if config.num_hidden_layers <= (config.mamba_step // 2):
@@ -502,7 +504,7 @@ direct_register_custom_op(
 class DenseMLP(nn.Module):
     def __init__(
         self,
-        config: Plamo2Config,
+        config: "Plamo2Config",
         quant_config: QuantizationConfig | None = None,
        prefix: str = "",
     ) -> None:
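The Plamo 2 hunks above move the `Plamo2Config` stub, which subclasses `PretrainedConfig` purely for type hinting, behind `if TYPE_CHECKING:` so that it is never defined at runtime, and quote the annotations that reference it (`config: "Plamo2Config"`). Presumably this avoids subclassing `PretrainedConfig` at import time and so insulates the module from API changes in Transformers v5. A minimal, self-contained sketch of the pattern; `FakeConfig` and `is_hybrid_layer` are invented placeholder names, not vLLM code:

from types import SimpleNamespace
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Defined only for static type checkers; this block never runs, so the
    # class body carries no runtime dependency on any external library.
    class FakeConfig:
        num_hidden_layers: int
        mamba_step: int


def is_hybrid_layer(config: "FakeConfig", i: int) -> bool:
    # Quoted annotation: stored as a string and never looked up at runtime.
    return i % config.mamba_step != 0


# Any object with the right attributes satisfies the runtime contract.
print(is_hybrid_layer(SimpleNamespace(num_hidden_layers=4, mamba_step=2), 3))

The remaining hunks apply the same treatment to the Plamo 3 file.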
@@ -4,7 +4,7 @@
 
 from collections.abc import Iterable
 from itertools import islice
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 import torch
 from torch import nn
@@ -46,28 +46,29 @@ from vllm.model_executor.models.utils import (
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors
 
 
-# Only used for type hinting.
-class Plamo3Config(PretrainedConfig):  # type: ignore
-    model_type: str = "plamo3"
+if TYPE_CHECKING:
 
-    hidden_size: int
-    num_hidden_layers: int
-    rms_norm_eps: float
-    # Attention
-    num_attention_heads: int
-    head_dim: int
-    num_key_value_heads: int
-    # vllm rename `sliding_window` attr to `interleaved_sliding_window`
-    # if `sliding_window` is list
-    interleaved_sliding_window: list[int | None]
-    sliding_window_pattern: int
-    rope_parameters: dict[str, Any]
-    rope_local_theta: int
-    # MLP
-    intermediate_size: int
-    # Tokenizer
-    vocab_size: int
+    class Plamo3Config(PretrainedConfig):  # type: ignore
+        model_type: str = "plamo3"
+
+        hidden_size: int
+        num_hidden_layers: int
+        rms_norm_eps: float
+        # Attention
+        num_attention_heads: int
+        head_dim: int
+        num_key_value_heads: int
+        # vllm rename `sliding_window` attr to `interleaved_sliding_window`
+        # if `sliding_window` is list
+        interleaved_sliding_window: list[int | None]
+        sliding_window_pattern: int
+        rope_parameters: dict[str, Any]
+        rope_local_theta: int
+        # MLP
+        intermediate_size: int
+        # Tokenizer
+        vocab_size: int
 
 
 def rms_norm_weight_loader(offset: float) -> LoaderFunction:
@@ -80,7 +81,7 @@ def rms_norm_weight_loader(offset: float) -> LoaderFunction:
 class DenseMLP(nn.Module):
     def __init__(
         self,
-        config: Plamo3Config,
+        config: "Plamo3Config",
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
     ) -> None:
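A note on why the annotations in the Plamo 2/3 hunks must be quoted: since `Plamo2Config` and `Plamo3Config` now exist only under `TYPE_CHECKING`, an unquoted annotation would be evaluated when the module is imported and raise `NameError` on interpreters that evaluate annotations eagerly (no `from __future__ import annotations` and no deferred evaluation). A tiny illustration with an invented name:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    class OnlyForTypeCheckers:
        pass

# Unquoted would fail at import time when annotations are evaluated eagerly:
#   def broken(cfg: OnlyForTypeCheckers) -> None: ...   # NameError
# Quoted is safe; the string is only resolved by static type checkers:
def ok(cfg: "OnlyForTypeCheckers") -> None:
    pass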