[GLM-OCR] GLM-OCR with MTP Support (#33005)

Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Author: Yuxuan Zhang
Date: 2026-01-26 22:24:43 +08:00
Committed by: GitHub
Parent: dcd80206b7
Commit: bb17e8f11c
14 changed files with 873 additions and 8 deletions
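This change wires GLM-OCR into vLLM as a GLM-4.1V-style multimodal model (GlmOcrForConditionalGeneration) and registers GlmOcrMTP as its multi-token-prediction draft for speculative decoding. A rough usage sketch follows (not part of this commit; the checkpoint id and the speculative-config keys are illustrative assumptions):

from vllm import LLM

# Sketch only: the repo id and config keys below are placeholders, not taken from this commit.
llm = LLM(
    model="zai-org/GLM-OCR",                            # hypothetical checkpoint id
    speculative_config={"num_speculative_tokens": 1},   # draft tokens produced by GlmOcrMTP
)

Image inputs would go through the usual vLLM multimodal prompt path; the GLM-4.1V processor, processing info, and dummy-inputs builder are reused for GLM-OCR, as the processor registration further below shows.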


@@ -39,13 +39,22 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backend import AttentionType
from .interfaces import SupportsLoRA, SupportsPP
from .llama import LlamaMLP as Glm4MLP
from .llama import LlamaModel
-from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
+from .utils import (
+    AutoWeightsLoader,
+    PPMissingLayer,
+    is_pp_missing_parameter,
+    maybe_prefix,
+)
class Glm4Attention(nn.Module):
@@ -78,7 +87,15 @@ class Glm4Attention(nn.Module):
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
-        config.rope_parameters.setdefault("partial_rotary_factor", 0.5)
+        rope_params = getattr(config, "rope_parameters", None)
+        if isinstance(rope_params, dict) and "partial_rotary_factor" in rope_params:
+            config.rope_parameters.setdefault(
+                "partial_rotary_factor", rope_params["partial_rotary_factor"]
+            )
+        else:
+            config.rope_parameters.setdefault("partial_rotary_factor", 0.5)
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim or hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
@@ -220,6 +237,73 @@ class Glm4Model(LlamaModel):
vllm_config=vllm_config, prefix=prefix, layer_type=Glm4DecoderLayer
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
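# Weights of the extra MTP prediction layers are loaded by the separate
# draft model (see GlmOcrMTP); the base language model skips them here.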
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue
if "rotary_emb.inv_freq" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale or zero point.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
@@ -293,3 +377,16 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
def get_spec_layer_idx_from_weight_name(
config: Glm4Config, weight_name: str
) -> int | None:
if hasattr(config, "num_nextn_predict_layers") and (
config.num_nextn_predict_layers > 0
):
layer_idx = config.num_hidden_layers
for i in range(config.num_nextn_predict_layers):
if f"layers.{layer_idx + i}." in weight_name:
return layer_idx + i
return None
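To make the helper above concrete (the layer counts are hypothetical, not from this commit):

# Suppose config.num_hidden_layers = 40 and config.num_nextn_predict_layers = 1.
get_spec_layer_idx_from_weight_name(config, "model.layers.40.eh_proj.weight")     # -> 40 (MTP layer)
get_spec_layer_idx_from_weight_name(config, "model.layers.3.mlp.up_proj.weight")  # -> None (base LM)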


@@ -24,7 +24,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-4V model compatible with HuggingFace weights."""
"""Inference-only GLM-4.1V & GLM-4.6V-Flash, AutoGLM-Phone-9B model
compatible with HuggingFace weights."""
import itertools
import math
@@ -1418,7 +1419,7 @@ class Glm4vForConditionalGeneration(
prefix=maybe_prefix(prefix, "visual"),
)
-        if config.model_type == "glm4v":
+        if config.model_type in ("glm4v", "glm_ocr"):
architectures = ["Glm4ForCausalLM"]
elif config.model_type == "glm4v_moe":
architectures = ["Glm4MoeForCausalLM"]


@@ -0,0 +1,389 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/Glm4v/modeling_Glm4v.py
# Copyright 2026 The ZhipuAI Team.
# Copyright 2026 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-OCR model compatible with HuggingFace weights."""
from collections.abc import Callable
from functools import partial
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
from einops import rearrange
if TYPE_CHECKING:
from transformers.models.glm_ocr.configuration_glm_ocr import GlmOcrVisionConfig
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.models.glm4_1v import (
Glm4vDummyInputsBuilder,
Glm4vForConditionalGeneration,
Glm4vMultiModalProcessor,
Glm4vPatchMerger,
Glm4vProcessingInfo,
Glm4vVisionBlock,
Glm4vVisionMLP,
Glm4vVisionPatchEmbed,
Glm4vVisionTransformer,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from .utils import (
maybe_prefix,
)
from .vision import (
get_vit_attn_backend,
is_vit_use_data_parallel,
)
logger = init_logger(__name__)
class GlmOcrVisionMLP(Glm4vVisionMLP):
pass
class GlmOcrVisionAttention(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
# Per attention head and per partition values.
use_data_parallel = is_vit_use_data_parallel()
self.tp_size = (
1 if use_data_parallel else get_tensor_model_parallel_world_size()
)
self.tp_rank = (
0 if use_data_parallel else parallel_state.get_tensor_model_parallel_rank()
)
self.hidden_size_per_attention_head = dist_utils.divide(
projection_size, num_heads
)
self.num_attention_heads_per_partition = dist_utils.divide(
num_heads, self.tp_size
)
self.head_dim = embed_dim // num_heads
self.q_norm = RMSNorm(self.head_dim, eps=1e-5)
self.k_norm = RMSNorm(self.head_dim, eps=1e-5)
self.qkv = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.hidden_size_per_attention_head,
total_num_heads=num_heads,
total_num_kv_heads=num_heads,
bias=True,
quant_config=quant_config,
# Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv",
disable_tp=use_data_parallel,
)
self.proj = RowParallelLinear(
input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
bias=True,
disable_tp=use_data_parallel,
)
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5,
)
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
q, k, v = qkv.chunk(3, dim=2)
# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
new_shape = (
seq_len,
bs,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
q, k, v = (x.view(*new_shape) for x in (q, k, v))
return q, k, v
def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
q, k, v = self.split_qkv(x)
# RMSNorm on q, k
q_shape, k_shape = q.shape, k.shape
q = self.q_norm(q.reshape(-1, self.head_dim)).view(q_shape)
k = self.k_norm(k.reshape(-1, self.head_dim)).view(k_shape)
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
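# Rotate q and k with a single kernel call by stacking them along the batch
# dim, then split them back afterwards.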
# [2 * b, s, heads, head_dim]
qk_concat = torch.cat([q, k], dim=0)
qk_rotated = self.apply_rotary_emb(
qk_concat,
rotary_pos_emb_cos,
rotary_pos_emb_sin,
)
q, k = torch.chunk(qk_rotated, 2, dim=0)
context_layer = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
output, _ = self.proj(context_layer)
return output
class GlmOcrVisionBlock(Glm4vVisionBlock):
def __init__(
self,
dim: int,
num_heads: int,
mlp_hidden_dim: int,
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__(
dim,
num_heads,
mlp_hidden_dim,
norm_layer,
quant_config,
prefix,
)
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.norm1 = norm_layer(dim)
self.norm2 = norm_layer(dim)
self.attn = GlmOcrVisionAttention(
embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
self.mlp = GlmOcrVisionMLP(
dim,
mlp_hidden_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
class GlmOcrVisionPatchEmbed(Glm4vVisionPatchEmbed):
pass
class GlmOcrPatchMerger(Glm4vPatchMerger):
pass
class GlmOcrVisionTransformer(Glm4vVisionTransformer):
def __init__(
self,
vision_config: GlmOcrVisionConfig,
norm_eps: float = 1e-5,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__(vision_config, norm_eps, quant_config, prefix)
del self.post_conv_layernorm
del self.embeddings
patch_size = vision_config.patch_size
temporal_patch_size = vision_config.temporal_patch_size
in_channels = vision_config.in_channels
depth = vision_config.depth
self.hidden_size = vision_config.hidden_size
self.num_heads = vision_config.num_heads
self.patch_size = vision_config.patch_size
self.spatial_merge_size = vision_config.spatial_merge_size
self.out_hidden_size = vision_config.out_hidden_size
self.patch_embed = Glm4vVisionPatchEmbed(
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
in_channels=in_channels,
hidden_size=self.hidden_size,
)
norm_layer = partial(RMSNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = get_rope(
head_size=head_dim,
max_position=8192,
is_neox_style=True,
rope_parameters={"partial_rotary_factor": 0.5},
)
self.blocks = nn.ModuleList(
[
GlmOcrVisionBlock(
dim=self.hidden_size,
num_heads=self.num_heads,
mlp_hidden_dim=vision_config.intermediate_size,
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
)
for layer_idx in range(depth)
]
)
self.merger = GlmOcrPatchMerger(
d_model=vision_config.out_hidden_size,
context_dim=vision_config.out_hidden_size * vision_config.in_channels,
quant_config=quant_config,
bias=False,
prefix=f"{prefix}.merger",
)
self.downsample = Conv2dLayer(
in_channels=vision_config.hidden_size,
out_channels=vision_config.out_hidden_size,
kernel_size=vision_config.spatial_merge_size,
stride=vision_config.spatial_merge_size,
)
self.post_layernorm = RMSNorm(
vision_config.hidden_size, eps=vision_config.rms_norm_eps
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
)
def forward(
self,
x: torch.Tensor,
grid_thw: torch.Tensor | list[list[int]],
) -> torch.Tensor:
if isinstance(grid_thw, list):
grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
# patchify
x = x.to(device=self.device, dtype=self.dtype)
x = self.patch_embed(x)
# compute position embedding
rotary_pos_emb_cos, rotary_pos_emb_sin, image_type_ids = self.rot_pos_emb(
grid_thw
)
# compute cu_seqlens
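# grid_thw rows are (t, h, w); each frame contributes h*w patch tokens, so
# cu_seqlens holds the per-frame boundaries [0, n0, n0+n1, ...] for varlen attention.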
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
# pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
# transformers
x = x.unsqueeze(1)
for blk in self.blocks:
x = blk(
x,
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
)
# adapter
x = self.post_layernorm(x)
x = x.view(-1, self.spatial_merge_size, self.spatial_merge_size, x.shape[-1])
x = x.permute(0, 3, 1, 2)
x = self.downsample(x).view(-1, self.out_hidden_size)
x = self.merger(x)
return x
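A shape walk-through of the adapter stage above; the concrete sizes are illustrative assumptions, only the kernel = stride = spatial_merge_size behaviour comes from the code:

# Illustrative sizes: hidden_size=1152, out_hidden_size=2048, spatial_merge_size=2
# after blocks + post_layernorm : (num_patches, 1, 1152)
# view                          : (num_patches // 4, 2, 2, 1152)   # 2x2 merge windows
# permute to NCHW               : (num_patches // 4, 1152, 2, 2)
# downsample (Conv2d, k=s=2)    : (num_patches // 4, 2048, 1, 1)
# view(-1, out_hidden_size)     : (num_patches // 4, 2048) tokens passed to the patch merger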
@MULTIMODAL_REGISTRY.register_processor(
Glm4vMultiModalProcessor,
info=Glm4vProcessingInfo,
dummy_inputs=Glm4vDummyInputsBuilder,
)
class GlmOcrForConditionalGeneration(Glm4vForConditionalGeneration):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = GlmOcrVisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
)


@@ -0,0 +1,285 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2026 The ZhipuAI Team.
# Copyright 2026 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-OCR MTP model compatible with HuggingFace weights."""
from collections.abc import Iterable
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from .glm4 import Glm4DecoderLayer, get_spec_layer_idx_from_weight_name
from .glm4_moe_lite_mtp import (
Glm4MoeLiteMultiTokenPredictor,
SharedHead,
)
from .interfaces import SupportsPP
from .utils import (
is_pp_missing_parameter,
maybe_prefix,
)
class GlmOcrMultiTokenPredictorLayer(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
config = vllm_config.speculative_config.draft_model_config.hf_config.text_config
self.config = config
quant_config = vllm_config.quant_config
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
self.device = current_platform.device_type
self.shared_head = SharedHead(
config=config, prefix=prefix, quant_config=quant_config
)
self.mtp_block = Glm4DecoderLayer(
vllm_config=vllm_config, prefix=prefix, config=self.config
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
spec_step_index: int = 0,
) -> torch.Tensor:
assert inputs_embeds is not None
# Mask inputs at position 0, as they are not needed by MTP.
inputs_embeds[positions[0] == 0] = 0
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)
hidden_states = self.eh_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1)
)
hidden_states, residual = self.mtp_block(
positions=positions, hidden_states=hidden_states, residual=None
)
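# Glm4DecoderLayer returns (hidden_states, residual) with the final residual
# add still pending, so apply it explicitly before the shared head.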
hidden_states = residual + hidden_states
return hidden_states
class GlmOcrMultiTokenPredictor(Glm4MoeLiteMultiTokenPredictor):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config.text_config
self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = config.num_nextn_predict_layers
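# MTP layers continue the base model's layer numbering: the first draft layer
# sits at index num_hidden_layers, matching checkpoint names "model.layers.<idx>.".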
self.layers = torch.nn.ModuleDict(
{
str(idx): GlmOcrMultiTokenPredictorLayer(
vllm_config=vllm_config,
prefix=f"{prefix}.layers.{idx}",
)
for idx in range(
self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers,
)
}
)
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
class GlmOcrMTP(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.config = vllm_config.model_config.hf_config.text_config
quant_config = vllm_config.quant_config
self.quant_config = quant_config
self.model = GlmOcrMultiTokenPredictor(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
self.expert_weights = []
self.num_layers = self.config.num_nextn_predict_layers
for layer in self.model.layers.values():
assert isinstance(layer, GlmOcrMultiTokenPredictorLayer)
layer = layer.mtp_block
assert isinstance(layer, Glm4DecoderLayer)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
hidden_states = self.model(
input_ids, positions, hidden_states, inputs_embeds, spec_step_idx
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
spec_step_idx: int = 0,
) -> torch.Tensor | None:
return self.model.compute_logits(hidden_states, spec_step_idx)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
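# The checkpoint's top-level lm_head maps to the MTP layer's shared head, and
# embed_tokens loads into the predictor's shared embedding.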
if name == "lm_head.weight":
spec_layer = self.model.mtp_start_layer_idx
name = f"model.layers.{spec_layer}.shared_head.head.weight"
elif name == "model.embed_tokens.weight":
spec_layer = self.model.mtp_start_layer_idx
else:
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is None:
continue
name = self._rewrite_spec_layer_name(spec_layer, name)
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale or zero point.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Some checkpoints include weight scale tensors for the
# LM head even when the quantized head isn't built. Skip
# them if the model does not expose a matching parameter
# to avoid KeyError during load.
if name.endswith(".weight_scale") and name not in params_dict:
continue
# According to the DeepSeek-V3 technical report, MTP modules share the
# embedding layer, so we only load the first copy of these weights.
if (
spec_layer != self.model.mtp_start_layer_idx
and ".layers" not in name
):
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
"""
Rewrite the weight name to match the format of the original model.
Add .mtp_block for modules in transformer layer block for spec layer
and rename shared layer weights to be top level.
"""
name = name.replace("model.language_model.layers", "model.layers")
spec_layer_weight_names = [
"embed_tokens",
"enorm",
"hnorm",
"eh_proj",
"shared_head",
]
shared_weight_names = ["embed_tokens"]
spec_layer_weight = False
shared_weight = False
for weight_name in spec_layer_weight_names:
if weight_name in name:
spec_layer_weight = True
if weight_name in shared_weight_names:
shared_weight = True
break
if not spec_layer_weight:
# treat the remaining weights as belonging to the transformer layer block
name = name.replace(
f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block."
)
elif shared_weight:
# treat shared weights as top level weights
name = name.replace(f"model.layers.{spec_layer}.", "model.")
return name
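To make the rewriting above concrete (layer index 40 is a hypothetical example):

# "model.layers.40.self_attn.q_proj.weight"  -> "model.layers.40.mtp_block.self_attn.q_proj.weight"
# "model.layers.40.embed_tokens.weight"      -> "model.embed_tokens.weight"   (shared weight)
# "model.layers.40.shared_head.head.weight"  -> unchanged (spec-layer weight, stays under its layer)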


@@ -319,8 +319,9 @@ _MULTIMODAL_MODELS = {
),
"GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
"Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"), # noqa: E501
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
"Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),
"GlmOcrForConditionalGeneration": ("glm_ocr", "GlmOcrForConditionalGeneration"), # noqa: E501
"GraniteSpeechForConditionalGeneration": (
"granite_speech",
"GraniteSpeechForConditionalGeneration",
@@ -472,6 +473,7 @@ _SPECULATIVE_DECODING_MODELS = {
"LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
"Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
"GlmOcrMTPModel": ("glm_ocr_mtp", "GlmOcrMTP"),
"MedusaModel": ("medusa", "Medusa"),
"OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
"Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),