[Models] Step-3.5-Flash (#33523)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: i-zhangmingming <i-zhangmingming@stepfun.com>
Co-authored-by: xiewuxun <xiewuxun@stepfun.com>
Co-authored-by: zetaohong <i-hongzetao@stepfun.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
(cherry picked from commit c3b40dc3e7)
This commit is contained in:
csy0225
2026-02-02 10:21:18 +08:00
committed by khluu
parent c7039a80b8
commit 8b45c58fe9
18 changed files with 3110 additions and 4 deletions

View File

@@ -456,6 +456,7 @@ th {
| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | |
| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ |
| `Step1ForCausalLM` | Step-Audio | `stepfun-ai/Step-Audio-EditX`, etc. | ✅︎ | ✅︎ |
| `Step3p5ForCausalLM` | Step-3.5-flash | `stepfun-ai/step-3.5-flash`, etc. | | ✅︎ |
| `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ |
| `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ |
| `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ |

View File

@@ -17,6 +17,8 @@ from vllm.model_executor.layers.activation import (
QuickGELU,
SiluAndMul,
SwigluOAIAndMul,
SwigluStepAndMul,
swiglustep_and_mul_triton,
)
from vllm.utils.torch_utils import set_random_seed
@@ -36,6 +38,7 @@ CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 e
"gelu_tanh",
"fatrelu",
"swigluoai_and_mul",
"swiglustep_and_mul",
],
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -75,9 +78,12 @@ def test_act_and_mul(
elif activation == "swigluoai_and_mul":
layer = SwigluOAIAndMul()
fn = torch.ops._C.swigluoai_and_mul
elif activation == "swiglustep_and_mul":
layer = SwigluStepAndMul()
fn = swiglustep_and_mul_triton
out = layer(x)
ref_out = layer.forward_native(x)
if activation == "swigluoai_and_mul":
if activation in ["swigluoai_and_mul", "swiglustep_and_mul"]:
rtol = {
# For fp16, change the relative tolerance from 1e-3 to 2e-3
torch.float16: 2e-3,
@@ -104,7 +110,7 @@ def test_act_and_mul(
opcheck(fn, (out, x, threshold))
elif activation == "swigluoai_and_mul":
opcheck(fn, (out, x, layer.alpha, layer.limit))
else:
elif activation != "swiglustep_and_mul":
opcheck(fn, (out, x))

View File

@@ -480,6 +480,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Step1ForCausalLM": _HfExamplesInfo(
"stepfun-ai/Step-Audio-EditX", trust_remote_code=True
),
"Step3p5ForCausalLM": _HfExamplesInfo(
"stepfun-ai/step-3.5-flash", is_available_online=False
),
"SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"),
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
@@ -1081,6 +1084,12 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"Qwen3NextMTP": _HfExamplesInfo(
"Qwen/Qwen3-Next-80B-A3B-Instruct", min_transformers_version="4.56.3"
),
"Step3p5MTP": _HfExamplesInfo(
"stepfun-ai/Step-3.5-Flash",
trust_remote_code=True,
speculative_model="stepfun-ai/Step-3.5-Flash",
is_available_online=False,
),
}
_TRANSFORMERS_BACKEND_MODELS = {

View File

@@ -40,6 +40,7 @@ MTPModelTypes = Literal[
"longcat_flash_mtp",
"mtp",
"pangu_ultra_moe_mtp",
"step3p5_mtp",
]
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
SpeculativeMethod = Literal[
@@ -252,6 +253,11 @@ class SpeculativeConfig:
{"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
)
if hf_config.model_type == "step3p5":
hf_config.model_type = "step3p5_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
hf_config.update({"n_predict": n_predict, "architectures": ["Step3p5MTP"]})
if initial_architecture == "MistralLarge3ForCausalLM":
hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})

View File

@@ -17,11 +17,63 @@ from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.collection_utils import LazyDict
logger = init_logger(__name__)
@triton.jit
def _swiglustep_and_mul_kernel(
o_ptr,
o_stride,
x_ptr,
x_stride,
limit: tl.constexpr,
d: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
) -> None:
i = tl.program_id(axis=0).to(tl.int64)
j = tl.program_id(axis=1)
o_row_ptr = o_ptr + o_stride * i
x_row_ptr = x_ptr + x_stride * i
offsets = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < d
gate = tl.load(x_row_ptr + offsets, mask=mask).to(tl.float32)
up = tl.load(x_row_ptr + offsets + d, mask=mask).to(tl.float32)
gate_silu = tl.sigmoid(gate) * gate
gate_clamped = tl.minimum(gate_silu, limit)
up_clamped = tl.minimum(tl.maximum(up, -limit), limit)
result = gate_clamped * up_clamped
result = result.to(x_ptr.dtype.element_ty)
tl.store(o_row_ptr + offsets, result, mask=mask)
def swiglustep_and_mul_triton(
output: torch.Tensor, input: torch.Tensor, limit: float = 7.0
):
b, n = input.shape
assert input.ndim == 2
assert n % 2 == 0
d = n // 2
def grid(meta):
return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))
_swiglustep_and_mul_kernel[grid](
output,
output.stride(0),
input,
input.stride(0),
limit=limit,
d=d,
BLOCK_SIZE=1024,
)
# --8<-- [start:fatrelu_and_mul]
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
@@ -304,6 +356,44 @@ class SwigluOAIAndMul(CustomOp):
return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}"
# --8<-- [start:swiglustep_and_mul]
@CustomOp.register("swiglustep_and_mul")
class SwigluStepAndMul(CustomOp):
"""An activation function for SwiGLU with clamping.
Computes x -> silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit)
where d = x.shape[-1] // 2.
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
def __init__(self, limit: float = 7.0):
super().__init__()
if limit is None:
raise ValueError("SwigluStepAndMul requires limit to be set.")
self.limit = limit
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
gate, up = x.chunk(2, dim=-1)
gate = F.silu(gate)
gate = gate.clamp(max=self.limit)
up = up.clamp(min=-self.limit, max=self.limit)
return gate * up
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
swiglustep_and_mul_triton(out, x, self.limit)
return out
def extra_repr(self) -> str:
return f"limit={repr(self.limit)}"
# --8<-- [start:gelu_new]
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):

View File

@@ -144,7 +144,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu"]
return activation in ["silu", "swiglustep"]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:

View File

@@ -1949,7 +1949,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu", "gelu", "swigluoai"]
return activation in ["silu", "gelu", "swigluoai", "swiglustep"]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:

View File

@@ -358,6 +358,11 @@ def apply_moe_activation(
torch.ops._C.gelu_and_mul(output, input)
elif activation == "swigluoai":
torch.ops._C.swigluoai_and_mul(output, input)
elif activation == "swiglustep":
from vllm.model_executor.layers.activation import swiglustep_and_mul_triton
swiglustep_and_mul_triton(output, input)
# Activations without gated multiplication
elif activation == SILU_NO_MUL:
output.copy_(F.silu(input))

View File

@@ -188,6 +188,7 @@ _TEXT_GENERATION_MODELS = {
"SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
"Step1ForCausalLM": ("step1", "Step1ForCausalLM"),
"Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
"Step3p5ForCausalLM": ("step3p5", "Step3p5ForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
@@ -476,6 +477,7 @@ _SPECULATIVE_DECODING_MODELS = {
"MedusaModel": ("medusa", "Medusa"),
"OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
"Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
"Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
# Temporarily disabled.
# # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
# "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),

View File

@@ -0,0 +1,894 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Jurassic model."""
from collections.abc import Iterable
from typing import Any
import torch
from torch import nn
from torch.nn.parameter import Parameter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (
get_dp_group,
get_ep_group,
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul, SwigluStepAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backend import AttentionType
from .interfaces import MixtureOfExperts, SupportsPP
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
WeightsMapper,
extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
logger = init_logger(__name__)
class FP32ReplicatedLinear(ReplicatedLinear):
"""
Use FP32 for higher precision.
"""
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
assert self.params_dtype == torch.float32
return super().forward(x.to(torch.float32))
class Step3p5MLP(nn.Module):
def __init__(
self,
config: ModelConfig,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: QuantizationConfig | None = None,
reduce_results: bool = True,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
)
self.act_fn = SiluAndMul()
self.prefix = prefix
self.hidden_size = hidden_size
self.limit = None
layer_idx = extract_layer_index(prefix)
if (
config.swiglu_limits_shared
and config.swiglu_limits_shared[layer_idx] is not None
and config.swiglu_limits_shared[layer_idx] != 0
):
self.limit = config.swiglu_limits_shared[layer_idx]
self.act_fn = SwigluStepAndMul(limit=self.limit)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(hidden_states)
intermediate_act = self.act_fn(gate_up)
output, _ = self.down_proj(intermediate_act)
return output
class Step3p5Attention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
head_dim: int | None = None,
rms_norm_eps: float = 1e-06,
qkv_bias: bool = False,
rope_theta: float | list[float] | None = 10000,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
rope_scaling: dict[str, Any] | None = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
# Step3p5 specific args
sliding_window: int | None = None,
use_head_wise_attn_gate: bool = False,
layer_types: list = None,
use_rope_layers: list = None,
yarn_only_types: list = None,
swa_num_attention_heads: int | None = None,
partial_rotary_factor: float = 1.0,
):
super().__init__()
self.hidden_size = hidden_size
self.total_num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
self.layer_idx = extract_layer_index(prefix)
if layer_types:
enable_sliding_window = layer_types[self.layer_idx] == "sliding_attention"
else:
enable_sliding_window = self.layer_idx % 2 == 0
if yarn_only_types and layer_types[self.layer_idx] not in yarn_only_types:
rope_scaling = None
if sliding_window is not None and enable_sliding_window:
sliding_window = sliding_window
if swa_num_attention_heads is not None:
num_heads = swa_num_attention_heads
self.total_num_heads = swa_num_attention_heads
else:
sliding_window = None
if isinstance(rope_theta, list):
rope_theta = rope_theta[self.layer_idx]
self.rank = get_tensor_model_parallel_rank()
self.partial_rotary_factor = partial_rotary_factor
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim or hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
if rope_scaling is not None and not isinstance(rope_scaling, dict):
raise ValueError("rope_scaling must be a dict for Step3p5Attention.")
rope_parameters: dict[str, Any] = (
dict(rope_scaling) if rope_scaling is not None else {}
)
rope_parameters.setdefault("rope_type", "default")
rope_parameters["rope_theta"] = self.rope_theta
rope_parameters["partial_rotary_factor"] = partial_rotary_factor
self.rotary_emb = get_rope(
head_size=self.head_dim,
max_position=max_position,
rope_parameters=rope_parameters,
)
self.q_norm = GemmaRMSNorm(self.head_dim, rms_norm_eps)
self.k_norm = GemmaRMSNorm(self.head_dim, rms_norm_eps)
self.use_head_wise_attn_gate = use_head_wise_attn_gate
if use_head_wise_attn_gate:
self.g_proj = ColumnParallelLinear(
hidden_size,
self.total_num_heads,
bias=False,
prefix=f"{prefix}.g_proj",
)
self.use_rope = True
if use_rope_layers:
self.use_rope = use_rope_layers[self.layer_idx]
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
per_layer_sliding_window=sliding_window,
attn_type=attn_type,
)
self.max_position_embeddings = max_position
assert self.partial_rotary_factor == 1 or self.partial_rotary_factor == 0.5
self.rotary_dim = (
self.head_dim if self.partial_rotary_factor == 1 else self.head_dim // 2
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm inline similar to Qwen3 MOE attention
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
q_by_head = self.q_norm(q_by_head.contiguous())
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
k_by_head = self.k_norm(k_by_head.contiguous())
k = k_by_head.view(k.shape)
if self.use_rope:
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
if self.use_head_wise_attn_gate:
extra_dims, _ = self.g_proj(hidden_states)
output = (
attn_output.view(*attn_output.shape[:-1], self.num_heads, self.head_dim)
* extra_dims.unsqueeze(-1).sigmoid()
)
attn_output = output.view(*attn_output.shape)
output, _ = self.o_proj(attn_output)
return output
class FusedMoEBlock(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.layer_idx = extract_layer_index(prefix)
self.ep_size = get_ep_group().device_group.size()
self.ep_rank = get_ep_group().device_group.rank()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
self.hidden_size = config.hidden_size
self.enable_eplb = parallel_config.enable_eplb
self.n_routed_experts = config.moe_num_experts
self.n_logical_experts = self.n_routed_experts
self.n_redundant_experts = parallel_config.eplb_config.num_redundant_experts
self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
self.physical_expert_end = (
self.physical_expert_start + self.n_local_physical_experts
)
if self.tp_size > config.moe_num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.moe_num_experts}."
)
self.gate = FP32ReplicatedLinear(
config.hidden_size,
config.moe_num_experts,
bias=False,
quant_config=None,
params_dtype=torch.float32, # Use FP32 for higher precision.
prefix=f"{prefix}.gate",
)
self.use_moe_router_bias = config.use_moe_router_bias
assert self.use_moe_router_bias, "Only support use_moe_router_bias is true."
self.routed_scaling_factor = config.moe_router_scaling_factor
self.router_bias = nn.Parameter(
torch.zeros(config.moe_num_experts, dtype=torch.float32),
requires_grad=False,
)
self.need_fp32_gate = config.need_fp32_gate
assert self.need_fp32_gate, (
"Router logits must use FP32 precision for numerical stability."
)
activation = "silu"
swiglu_limits = config.swiglu_limits or []
swiglu_limit = (
swiglu_limits[self.layer_idx]
if self.layer_idx < len(swiglu_limits)
else None
)
if swiglu_limit not in (None, 0):
swiglu_limit = float(swiglu_limit)
assert swiglu_limit == 7.0, (
"Swiglu limit in fused moe block only suport 7.0 now."
)
activation = "swiglustep"
logger.debug(
"step3p5 layer_idx: %s, activation: %s, limit: %s",
self.layer_idx,
activation,
swiglu_limit,
)
self.share_expert = Step3p5MLP(
config=config,
hidden_size=self.hidden_size,
intermediate_size=config.share_expert_dim,
hidden_act="silu",
reduce_results=False,
quant_config=quant_config,
prefix=f"{prefix}.share_expert",
)
self.experts = SharedFusedMoE(
shared_experts=self.share_expert,
gate=self.gate,
num_experts=config.moe_num_experts,
top_k=config.moe_top_k,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_expert_weight,
quant_config=quant_config,
activation=activation,
prefix=f"{prefix}.experts",
scoring_func=getattr(config, "moe_router_activation", "sigmoid"),
e_score_correction_bias=self.router_bias,
routed_scaling_factor=config.moe_router_scaling_factor,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.experts.is_internal_router:
# In this case, the gate/router runs inside the FusedMoE class
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=hidden_states
)
else:
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
shared_output, final_hidden_states = fused_moe_out
if self.share_expert is None:
assert shared_output is None
if self.share_expert is not None:
assert shared_output is not None
final_hidden_states += shared_output
if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states.view(num_tokens, hidden_dim)
class Step3p5DecoderLayer(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
self.hidden_size = config.hidden_size
layer_idx = extract_layer_index(prefix)
self.layer_idx = layer_idx
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
if cache_config is not None:
cache_config.sliding_window = None
if config.att_impl_type == "GQA":
num_attention_heads = None
num_attention_groups = None
head_dim = None
if (
getattr(config, "attention_other_setting", None)
and getattr(config, "layer_types", [])
and config.layer_types[layer_idx]
== config.attention_other_setting["attention_type"]
):
num_attention_heads = config.attention_other_setting[
"num_attention_heads"
]
num_attention_groups = config.attention_other_setting[
"num_attention_groups"
]
head_dim = config.attention_other_setting["head_dim"]
partial_rotary_factors = getattr(config, "partial_rotary_factors", [])
self.self_attn = Step3p5Attention(
hidden_size=self.hidden_size,
num_heads=num_attention_heads
if num_attention_heads
else config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=num_attention_groups
if num_attention_groups
else config.num_attention_groups,
rope_theta=config.rope_theta,
rms_norm_eps=config.rms_norm_eps,
qkv_bias=getattr(config, "attention_bias", False),
head_dim=head_dim if head_dim else getattr(config, "head_dim", None),
cache_config=cache_config,
quant_config=quant_config,
rope_scaling=getattr(config, "rope_scaling", None),
sliding_window=getattr(config, "sliding_window", None),
use_head_wise_attn_gate=getattr(
config, "use_head_wise_attn_gate", False
),
layer_types=getattr(config, "layer_types", []),
use_rope_layers=getattr(config, "use_rope_layers", []),
yarn_only_types=getattr(config, "yarn_only_types", []),
partial_rotary_factor=partial_rotary_factors[layer_idx]
if partial_rotary_factors
else 1.0,
prefix=f"{prefix}.self_attn",
)
else:
raise ValueError(
f"Unsupported attention implementation: {config.att_impl_type}"
)
self.use_moe = False
self.tp_group = get_tp_group()
self.use_fused_all_reduce = (
get_tensor_model_parallel_world_size() > 1
and get_dp_group().world_size == 1
)
if self.use_fused_all_reduce:
logger.warning_once("Enable custom fused all reduce...")
else:
logger.warning_once("Disable custom fused all reduce...")
moe_layers_enum = getattr(config, "moe_layers_enum", None)
if moe_layers_enum is not None:
moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")]
else:
moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
if layer_idx in moe_layers_idx:
self.moe = FusedMoEBlock(
vllm_config,
prefix=f"{prefix}.moe",
)
self.use_moe = True
else:
self.mlp = Step3p5MLP(
config=config,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act="silu",
quant_config=quant_config,
reduce_results=True,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
self.post_attention_layernorm = GemmaRMSNorm(
config.hidden_size, config.rms_norm_eps
)
self.prefix = prefix
def add_and_maybe_inplace_all_reduce(
self, in1: torch.Tensor, in2: torch.Tensor
) -> torch.Tensor:
if not self.use_fused_all_reduce:
return in1 + in2
return self.tp_group.all_reduce(in1 + in2)
def forward(
self, positions: torch.Tensor, hidden_states: torch.Tensor
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
hidden_states += residual
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
if self.use_moe:
ffn_output = self.moe(hidden_states)
else:
ffn_output = self.mlp(hidden_states)
hidden_states = ffn_output + residual
return hidden_states
@support_torch_compile
class Step3p5Model(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
self.vllm_config = vllm_config
config = vllm_config.model_config.hf_config
self.vocab_size = config.vocab_size
self.config = config
self.moe_num_experts = config.moe_num_experts
if get_pp_group().is_first_rank or (
config.tie_word_embeddings and get_pp_group().is_last_rank
):
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
)
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Step3p5DecoderLayer(
vllm_config,
prefix=prefix,
),
prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank:
self.norm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states"], config.hidden_size
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_input_ids(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(positions, hidden_states)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{
"hidden_states": hidden_states,
}
)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
config = self.config
assert config.num_attention_groups > 1, "Only support GQA"
qkv_params_mapping = []
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
expert_params_mapping = [
(".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
(".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
(".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
]
disable_moe_stacked_params = [data[1] for data in expert_params_mapping]
for name, loaded_weight in weights:
if name.startswith("model."):
local_name = name[len("model.") :]
full_name = name
else:
local_name = name
full_name = f"model.{name}" if name else "model"
spec_layer = get_spec_layer_idx_from_weight_name(config, full_name)
if spec_layer is not None:
continue # skip spec decode layers for main model
# Skip any layers beyond the main model's depth (e.g., MTP layers)
if full_name.startswith("model.layers."):
parts = full_name.split(".")
if len(parts) > 2 and parts[2].isdigit():
layer_idx = int(parts[2])
if layer_idx >= config.num_hidden_layers:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in local_name:
continue
if any(
disable_moe_stacked_param in local_name
for disable_moe_stacked_param in disable_moe_stacked_params
):
continue
replaced_name = local_name.replace(weight_name, param_name)
if is_pp_missing_parameter(replaced_name, self):
continue
if replaced_name not in params_dict:
continue
param = params_dict[replaced_name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(replaced_name)
break
else:
for param_name, weight_name, shard_id in expert_params_mapping:
if weight_name not in local_name:
continue
replaced_name = local_name.replace(weight_name, param_name)
if is_pp_missing_parameter(replaced_name, self):
continue
if (
replaced_name.endswith(".bias")
or replaced_name.endswith("_bias")
) and replaced_name not in params_dict:
continue
if replaced_name not in params_dict:
continue
param = params_dict[replaced_name]
weight_loader = param.weight_loader
moe_expert_num = self.moe_num_experts
assert loaded_weight.shape[0] == moe_expert_num
for expert_id in range(moe_expert_num):
loaded_weight_expert = loaded_weight[expert_id]
weight_loader(
param,
loaded_weight_expert,
replaced_name,
shard_id=shard_id,
expert_id=expert_id,
)
loaded_params.add(replaced_name)
break
else:
for (
param_name,
weight_name,
start_idx,
end_idx,
) in qkv_params_mapping:
if weight_name not in local_name:
continue
replaced_name = local_name.replace(weight_name, param_name)
if is_pp_missing_parameter(replaced_name, self):
continue
if replaced_name not in params_dict:
continue
param = params_dict[replaced_name]
dim = param.shape[param.output_dim]
begin_idx = int(start_idx * dim)
end_idx = int(end_idx * dim)
param_slice = param.narrow(
param.output_dim, begin_idx, end_idx - begin_idx
)
param_slice.copy_(loaded_weight)
loaded_params.add(replaced_name)
break
else:
if is_pp_missing_parameter(local_name, self):
continue
if "expert_bias" in local_name:
logger.warning_once("ignore expert_bias")
continue
if local_name not in params_dict:
continue
param = params_dict[local_name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(local_name)
return loaded_params
class Step3p5ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={".share_expert.": ".moe.share_expert."}
)
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__()
config = vllm_config.model_config.hf_config
lora_config = vllm_config.lora_config
self.config = config
self.vllm_config = vllm_config
self.model = Step3p5Model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
self.moe_layers: list[FusedMoEBlock] = []
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, Step3p5DecoderLayer)
if hasattr(layer, "moe") and isinstance(layer.moe, FusedMoEBlock):
self.moe_layers.append(layer.moe)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
if not lora_config
else lora_config.lora_vocab_padding_size,
)
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size
)
else:
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)
# Set MoE hyperparameters
self.expert_weights = []
assert len(self.moe_layers) > 0, "No MoE layers found in the model."
example_layer = self.moe_layers[0]
self.num_moe_layers = len(self.moe_layers)
self.num_expert_groups = 1
self.num_shared_experts = 0
self.num_logical_experts = example_layer.n_logical_experts
self.num_physical_experts = example_layer.n_physical_experts
self.num_local_physical_experts = example_layer.n_local_physical_experts
self.num_routed_experts = example_layer.n_routed_experts
self.num_redundant_experts = example_layer.n_redundant_experts
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
):
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.model.norm(hidden_states)
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_tokens(input_ids)
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
experts = layer.experts
assert isinstance(experts, FusedMoE)
# Register the expert weights.
self.expert_weights.append(experts.get_expert_weights())
experts.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for layer in self.moe_layers:
assert isinstance(layer, FusedMoEBlock)
layer.n_local_physical_experts = num_local_physical_experts
layer.n_physical_experts = num_physical_experts
layer.n_redundant_experts = self.num_redundant_experts
layer.experts.update_expert_map()
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_spec_layer_idx_from_weight_name(
config: ModelConfig, weight_name: str
) -> int | None:
if hasattr(config, "num_nextn_predict_layers") and (
config.num_nextn_predict_layers > 0
):
layer_idx = config.num_hidden_layers
for i in range(config.num_nextn_predict_layers):
if weight_name.startswith(
f"layers.{layer_idx + i}." # Step3p5Model
) or weight_name.startswith(f"model.layers.{layer_idx + i}."): # Step3p5MTP
return layer_idx + i
return None

View File

@@ -0,0 +1,315 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from .step3p5 import Step3p5DecoderLayer, get_spec_layer_idx_from_weight_name
from .utils import maybe_prefix
logger = init_logger(__name__)
class SharedHead(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
) -> None:
super().__init__()
self.norm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
self.head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.norm(hidden_states)
class Step3p5AMultiTokenPredictorLayer(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str,
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.enorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
self.hnorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
self.shared_head = SharedHead(config=config, quant_config=quant_config)
self.mtp_block = Step3p5DecoderLayer(
vllm_config,
prefix=f"{prefix}.mtp_block",
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
spec_step_index: int = 0,
) -> torch.Tensor:
assert inputs_embeds is not None
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)
hidden_states = self.eh_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1)
)
hidden_states = self.mtp_block(positions=positions, hidden_states=hidden_states)
return hidden_states
class Step3p5AMultiTokenPredictor(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = config.num_nextn_predict_layers
# to map the exact layer index from weights
self.layers = torch.nn.ModuleDict(
{
str(idx): Step3p5AMultiTokenPredictorLayer(
vllm_config,
f"{prefix}.layers.{idx}",
)
for idx in range(
self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers,
)
}
)
self.logits_processor = LogitsProcessor(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = spec_step_idx % self.num_mtp_layers
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
input_ids,
positions,
previous_hidden_states,
inputs_embeds,
current_step_idx,
)
def compute_logits(
self,
hidden_states: torch.Tensor,
spec_step_idx: int = 0,
) -> torch.Tensor:
current_step_idx = spec_step_idx % self.num_mtp_layers
mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)]
logits = self.logits_processor(
mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states)
)
return logits
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
class Step3p5MTP(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
self.model = Step3p5AMultiTokenPredictor(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
hidden_states = self.model(
input_ids, positions, hidden_states, inputs_embeds, spec_step_idx
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
spec_step_idx: int = 0,
) -> torch.Tensor | None:
return self.model.compute_logits(hidden_states, spec_step_idx)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = [
(".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
(".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
(".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if "embed_tokens" not in name and spec_layer is None:
continue
name = self._rewrite_spec_layer_name(spec_layer, name)
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
if "experts" in name or "moe" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if (
name.endswith(".bias") or name.endswith("_bias")
) and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
for expert_id in range(loaded_weight.shape[0]):
loaded_weight_expert = loaded_weight[expert_id]
weight_loader(
param,
loaded_weight_expert,
name,
shard_id=shard_id,
expert_id=expert_id,
)
loaded_params.add(name)
break
else:
# Skip loading extra bias for GPTQ models.
if (
name.endswith(".bias")
and name not in params_dict
or "tok_embeddings" in name
):
continue
if spec_layer is not None and ".transformer." in name:
name = name.replace(".transformer.", ".")
if "shared_head" in name:
name = name.replace("shared_head.output", "shared_head.head")
if "embed_tokens" in name:
assert (
hasattr(self.config, "num_nextn_predict_layers")
and self.config.num_nextn_predict_layers > 0
)
name = "model.embed_tokens.weight"
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
params_need_to_load = set(params_dict.keys())
# Some KV cache scales are optional: checkpoints may omit them and vLLM
# will fall back to default scales during initialization.
optional_params = {
name
for name, param in params_dict.items()
if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale"))
and getattr(param, "numel", lambda: 0)() == 1
and getattr(param, "requires_grad", False) is False
}
params_need_to_load -= optional_params
if params_need_to_load != loaded_params:
missing_params = list(params_need_to_load - loaded_params)
param_name_example = missing_params[0]
raise RuntimeError(
"Some parameters like "
f"{param_name_example} are not in the checkpoint and will falsely "
"use random initialization"
)
return loaded_params
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
"""
Rewrite the weight name to match the format of the original model.
Add .mtp_block for modules in transformer layer block for spec layer
"""
spec_layer_weight_names = [
"embed_tokens",
"enorm",
"hnorm",
"eh_proj",
"shared_head",
]
spec_layer_weight = False
for weight_name in spec_layer_weight_names:
if weight_name in name:
spec_layer_weight = True
break
if not spec_layer_weight:
# treat rest weights as weights for transformer layer block
name = name.replace(
f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block."
)
return name

View File

@@ -84,6 +84,10 @@ _REASONING_PARSERS_TO_REGISTER = {
"step3_reasoning_parser",
"Step3ReasoningParser",
),
"step3p5": (
"step3p5_reasoning_parser",
"Step3p5ReasoningParser",
),
}

View File

@@ -0,0 +1,153 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.tokenizers import TokenizerLike
class Step3p5ReasoningParser(BaseThinkingReasoningParser):
"""
Reasoning parser for Step3p5 model.
Step3p5 uses the <think>...</think> format, but it tends to emit an extra
newline immediately before and/or after the </think> token. This parser trims:
- the newline right before </think>
- the newline right after </think>
"""
@property
def start_token(self) -> str:
return "<think>"
@property
def end_token(self) -> str:
return "</think>"
def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs)
# Used to hold a trailing "\n" from reasoning content so we can decide
# whether it is immediately before </think>.
self._pending_reasoning_newline = False
# Used to delay the reasoning end detection.
# This is necessary to remove the newline appears immediately after </think>,
# which may cause the end detection to be delayed by one round.
self.end_offset = 1
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
if self.end_token_id in input_ids and self.end_offset > 0:
self.end_offset -= 1
return False
return self.end_offset < 1
def is_reasoning_end_streaming(
self, input_ids: Sequence[int], delta_ids: Sequence[int]
) -> bool:
if self.end_token_id in input_ids and self.end_offset > 0:
self.end_offset -= 1
return False
return self.end_offset < 1
def extract_reasoning(
self,
model_output: str,
request: ChatCompletionRequest | ResponsesRequest,
) -> tuple[str | None, str | None]:
reasoning, content = super().extract_reasoning(model_output, request)
if reasoning is not None:
reasoning = reasoning.removesuffix("\n")
if content is not None:
content = content.removeprefix("\n")
return reasoning or None, content or None
def extract_reasoning_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> DeltaMessage | None:
# Drop the immediate newline that models often emit after </think>.
if previous_text.endswith(self.end_token) and delta_text:
if delta_text == "\n":
return None
elif delta_text.startswith("\n"):
remaining = delta_text.removeprefix("\n")
return DeltaMessage(content=remaining) if remaining else None
ret = super().extract_reasoning_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
delta_token_ids,
)
if ret is None:
return None
# Compatibility path for models that don't generate the start token:
# treat everything before </think> as reasoning and everything after
# as content.
if (
self.start_token_id not in previous_token_ids
and self.start_token_id not in delta_token_ids
):
if self.end_token_id in delta_token_ids:
end_index = delta_text.find(self.end_token)
reasoning = delta_text[:end_index]
content = delta_text[end_index + len(self.end_token) :]
ret = DeltaMessage(reasoning=reasoning, content=content or None)
elif self.end_token_id in previous_token_ids:
ret = DeltaMessage(content=delta_text)
else:
ret = DeltaMessage(reasoning=delta_text)
reasoning_to_output = ret.reasoning
content_to_output = ret.content
# Reasoning: handle the newline immediately before </think>.
if reasoning_to_output is not None:
if self._pending_reasoning_newline:
reasoning_to_output = "\n" + reasoning_to_output
self._pending_reasoning_newline = False
if reasoning_to_output.endswith("\n"):
reasoning_to_output = reasoning_to_output.removesuffix("\n")
if self.end_token in delta_text:
# Trailing "\n" is right before </think>, drop it.
self._pending_reasoning_newline = False
else:
# Hold the trailing "\n" until we know whether </think> follows.
self._pending_reasoning_newline = True
# Content: handle the newline immediately after </think>.
if content_to_output is not None:
# No need to get into parser again to remove newline after </think>.
self.end_offset -= 1
# If we have content, reasoning must have ended.
self._pending_reasoning_newline = False
if self.end_token in delta_text and content_to_output.startswith("\n"):
content_to_output = content_to_output.removeprefix("\n")
reasoning_to_output = reasoning_to_output or None
content_to_output = content_to_output or None
if reasoning_to_output is None and content_to_output is None:
return None
return DeltaMessage(reasoning=reasoning_to_output, content=content_to_output)

View File

@@ -134,6 +134,10 @@ _TOOL_PARSERS_TO_REGISTER = {
"step3_tool_parser",
"Step3ToolParser",
),
"step3p5": (
"step3p5_tool_parser",
"Step3p5ToolParser",
),
"xlam": (
"xlam_tool_parser",
"xLAMToolParser",

File diff suppressed because it is too large Load Diff

View File

@@ -96,6 +96,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
ultravox="UltravoxConfig",
step3_vl="Step3VLConfig",
step3_text="Step3TextConfig",
step3p5="Step3p5Config",
qwen3_asr="Qwen3ASRConfig",
qwen3_next="Qwen3NextConfig",
lfm2_moe="Lfm2MoeConfig",
tarsier2="Tarsier2Config",

View File

@@ -50,6 +50,8 @@ _CLASS_TO_MODULE: dict[str, str] = {
"Step3VLConfig": "vllm.transformers_utils.configs.step3_vl",
"Step3VisionEncoderConfig": "vllm.transformers_utils.configs.step3_vl",
"Step3TextConfig": "vllm.transformers_utils.configs.step3_vl",
"Step3p5Config": "vllm.transformers_utils.configs.step3p5",
"Qwen3ASRConfig": "vllm.transformers_utils.configs.qwen3_asr",
"Qwen3NextConfig": "vllm.transformers_utils.configs.qwen3_next",
"Tarsier2Config": "vllm.transformers_utils.configs.tarsier2",
# Special case: DeepseekV3Config is from HuggingFace Transformers
@@ -90,6 +92,8 @@ __all__ = [
"Step3VLConfig",
"Step3VisionEncoderConfig",
"Step3TextConfig",
"Step3p5Config",
"Qwen3ASRConfig",
"Qwen3NextConfig",
"Tarsier2Config",
]

View File

@@ -0,0 +1,100 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from transformers.configuration_utils import PretrainedConfig
class Step3p5Config(PretrainedConfig):
model_type = "step3p5"
def __init__(
self,
hidden_size: int = 5120,
intermediate_size: int = 13312,
num_attention_heads: int = 40,
num_attention_groups: int = 8,
num_hidden_layers: int = 48,
max_seq_len: int = 4096,
vocab_size: int = 65536,
rms_norm_eps: float = 1e-5,
moe_every_n_layer: int = 2,
use_moe: bool = False,
moe_intermediate_size: int = 10240,
moe_num_experts: int = 16,
moe_top_k: int = 4,
moe_layer_offset: int = 0,
rope_theta: float | list[float] | None = 500000,
rope_scaling: dict[str, Any] | None = None,
head_dim: int | None = None,
share_expert_dim: int | None = None,
norm_expert_weight: bool = True,
bos_token_id: list[int] | int | None = None,
eos_token_id: list[int] | int | None = None,
moe_router_activation: str = "softmax",
moe_router_scaling_factor: float = 1.0,
att_impl_type: str = "GQA",
use_head_wise_attn_gate: bool = False,
use_moe_router_bias: bool = True,
need_fp32_gate: bool = True,
layer_types: list[str] | None = None,
use_rope_layers: list[bool] | None = None,
yarn_only_types: list[str] | None = None,
attention_other_setting: dict[str, Any] | None = None,
num_nextn_predict_layers: int = 0,
swiglu_limits: list[float] | None = None,
swiglu_limits_shared: list[float] | None = None,
max_position_embeddings: int | None = None,
**kwargs,
):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.num_attention_groups = num_attention_groups
self.num_hidden_layers = num_hidden_layers
self.max_seq_len = max_seq_len
self.vocab_size = vocab_size
self.rms_norm_eps = rms_norm_eps
self.use_moe = use_moe
self.moe_intermediate_size = moe_intermediate_size
self.moe_every_n_layer = moe_every_n_layer
self.moe_num_experts = moe_num_experts
self.num_experts_per_tok = moe_top_k
self.moe_top_k = moe_top_k
self.moe_layer_offset = moe_layer_offset
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.head_dim = head_dim
if share_expert_dim is None:
self.share_expert_dim = self.moe_intermediate_size * self.moe_top_k
else:
self.share_expert_dim = share_expert_dim
self.norm_expert_weight = norm_expert_weight
self.max_position_embeddings = max_position_embeddings
self.moe_router_activation = moe_router_activation
self.moe_router_scaling_factor = moe_router_scaling_factor
self.use_moe_router_bias = use_moe_router_bias
self.need_fp32_gate = need_fp32_gate
self.att_impl_type = att_impl_type
self.use_head_wise_attn_gate = use_head_wise_attn_gate
self.layer_types = layer_types
self.use_rope_layers = use_rope_layers
self.yarn_only_types = yarn_only_types
self.attention_other_setting = attention_other_setting
self.num_nextn_predict_layers = num_nextn_predict_layers
self.swiglu_limits = swiglu_limits
self.swiglu_limits_shared = swiglu_limits_shared
resolved_bos_token_id = 1 if bos_token_id is None else bos_token_id
resolved_eos_token_id = [2, 3] if eos_token_id is None else eos_token_id
self.bos_token_id = resolved_bos_token_id
self.eos_token_id = resolved_eos_token_id
super().__init__(
bos_token_id=resolved_bos_token_id,
eos_token_id=resolved_eos_token_id,
**kwargs,
)