[GLM-OCR] GLM-OCR with MTP Support (#33005)
Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -39,13 +39,22 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.attention.backend import AttentionType
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .llama import LlamaMLP as Glm4MLP
|
||||
from .llama import LlamaModel
|
||||
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
is_pp_missing_parameter,
|
||||
maybe_prefix,
|
||||
)
|
||||
|
||||
|
||||
class Glm4Attention(nn.Module):
|
||||
@@ -78,7 +87,15 @@ class Glm4Attention(nn.Module):
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
config.rope_parameters.setdefault("partial_rotary_factor", 0.5)
|
||||
|
||||
rope_params = getattr(config, "rope_parameters", None)
|
||||
if isinstance(rope_params, dict) and "partial_rotary_factor" in rope_params:
|
||||
config.rope_parameters.setdefault(
|
||||
"partial_rotary_factor", rope_params["partial_rotary_factor"]
|
||||
)
|
||||
else:
|
||||
config.rope_parameters.setdefault("partial_rotary_factor", 0.5)
|
||||
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = head_dim or hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
@@ -220,6 +237,73 @@ class Glm4Model(LlamaModel):
|
||||
vllm_config=vllm_config, prefix=prefix, layer_type=Glm4DecoderLayer
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
|
||||
if spec_layer is not None:
|
||||
continue
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
if "scale" in name or "zero_point" in name:
|
||||
# Remapping the name of FP8 kv-scale or zero point.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
@@ -293,3 +377,16 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
|
||||
def get_spec_layer_idx_from_weight_name(
|
||||
config: Glm4Config, weight_name: str
|
||||
) -> int | None:
|
||||
if hasattr(config, "num_nextn_predict_layers") and (
|
||||
config.num_nextn_predict_layers > 0
|
||||
):
|
||||
layer_idx = config.num_hidden_layers
|
||||
for i in range(config.num_nextn_predict_layers):
|
||||
if f"layers.{layer_idx + i}." in weight_name:
|
||||
return layer_idx + i
|
||||
return None
|
||||
|
||||
@@ -24,7 +24,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GLM-4V model compatible with HuggingFace weights."""
|
||||
"""Inference-only GLM-4.1V & GLM-4.6V-Flash, AutoGLM-Phone-9B model
|
||||
compatible with HuggingFace weights."""
|
||||
|
||||
import itertools
|
||||
import math
|
||||
@@ -1418,7 +1419,7 @@ class Glm4vForConditionalGeneration(
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
|
||||
|
||||
if config.model_type == "glm4v":
|
||||
if config.model_type in ("glm4v", "glm_ocr"):
|
||||
architectures = ["Glm4ForCausalLM"]
|
||||
elif config.model_type == "glm4v_moe":
|
||||
architectures = ["Glm4MoeForCausalLM"]
|
||||
|
||||
389
vllm/model_executor/models/glm_ocr.py
Normal file
389
vllm/model_executor/models/glm_ocr.py
Normal file
@@ -0,0 +1,389 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/Glm4v/modeling_Glm4v.py
|
||||
# Copyright 2026 The ZhipuAI Team.
|
||||
# Copyright 2026 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GLM-OCR model compatible with HuggingFace weights."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.models.glm_ocr.configuration_glm_ocr import GlmOcrVisionConfig
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.mm_encoder_attention import (
|
||||
MMEncoderAttention,
|
||||
)
|
||||
from vllm.model_executor.layers.conv import Conv2dLayer
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||
ApplyRotaryEmb,
|
||||
)
|
||||
from vllm.model_executor.models.glm4_1v import (
|
||||
Glm4vDummyInputsBuilder,
|
||||
Glm4vForConditionalGeneration,
|
||||
Glm4vMultiModalProcessor,
|
||||
Glm4vPatchMerger,
|
||||
Glm4vProcessingInfo,
|
||||
Glm4vVisionBlock,
|
||||
Glm4vVisionMLP,
|
||||
Glm4vVisionPatchEmbed,
|
||||
Glm4vVisionTransformer,
|
||||
)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
from .utils import (
|
||||
maybe_prefix,
|
||||
)
|
||||
from .vision import (
|
||||
get_vit_attn_backend,
|
||||
is_vit_use_data_parallel,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GlmOcrVisionMLP(Glm4vVisionMLP):
|
||||
pass
|
||||
|
||||
|
||||
class GlmOcrVisionAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
projection_size: int,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# Per attention head and per partition values.
|
||||
use_data_parallel = is_vit_use_data_parallel()
|
||||
self.tp_size = (
|
||||
1 if use_data_parallel else get_tensor_model_parallel_world_size()
|
||||
)
|
||||
self.tp_rank = (
|
||||
0 if use_data_parallel else parallel_state.get_tensor_model_parallel_rank()
|
||||
)
|
||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||
projection_size, num_heads
|
||||
)
|
||||
self.num_attention_heads_per_partition = dist_utils.divide(
|
||||
num_heads, self.tp_size
|
||||
)
|
||||
|
||||
self.head_dim = embed_dim // num_heads
|
||||
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=1e-5)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=1e-5)
|
||||
|
||||
self.qkv = QKVParallelLinear(
|
||||
hidden_size=embed_dim,
|
||||
head_size=self.hidden_size_per_attention_head,
|
||||
total_num_heads=num_heads,
|
||||
total_num_kv_heads=num_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
# Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
|
||||
prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv",
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
self.proj = RowParallelLinear(
|
||||
input_size=projection_size,
|
||||
output_size=embed_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.proj",
|
||||
bias=True,
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
|
||||
self.attn = MMEncoderAttention(
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
head_size=self.hidden_size_per_attention_head,
|
||||
scale=self.hidden_size_per_attention_head**-0.5,
|
||||
)
|
||||
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
|
||||
|
||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
# [s, b, 3 * head * head_dim]
|
||||
seq_len, bs, _ = qkv.shape
|
||||
|
||||
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
|
||||
q, k, v = qkv.chunk(3, dim=2)
|
||||
|
||||
# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
|
||||
new_shape = (
|
||||
seq_len,
|
||||
bs,
|
||||
self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head,
|
||||
)
|
||||
q, k, v = (x.view(*new_shape) for x in (q, k, v))
|
||||
return q, k, v
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb_cos: torch.Tensor,
|
||||
rotary_pos_emb_sin: torch.Tensor,
|
||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||
) -> torch.Tensor:
|
||||
# [s, b, c] --> [s, b, head * 3 * head_dim]
|
||||
x, _ = self.qkv(x)
|
||||
|
||||
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
|
||||
q, k, v = self.split_qkv(x)
|
||||
|
||||
# RMSNorm on q, k
|
||||
q_shape, k_shape = q.shape, k.shape
|
||||
q = self.q_norm(q.reshape(-1, self.head_dim)).view(q_shape)
|
||||
k = self.k_norm(k.reshape(-1, self.head_dim)).view(k_shape)
|
||||
|
||||
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
|
||||
if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
|
||||
# [2 * b, s, heads, head_dim]
|
||||
qk_concat = torch.cat([q, k], dim=0)
|
||||
qk_rotated = self.apply_rotary_emb(
|
||||
qk_concat,
|
||||
rotary_pos_emb_cos,
|
||||
rotary_pos_emb_sin,
|
||||
)
|
||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||
|
||||
context_layer = self.attn(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
|
||||
|
||||
output, _ = self.proj(context_layer)
|
||||
return output
|
||||
|
||||
|
||||
class GlmOcrVisionBlock(Glm4vVisionBlock):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_hidden_dim: int,
|
||||
norm_layer: Callable[[int], nn.Module] | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_hidden_dim,
|
||||
norm_layer,
|
||||
quant_config,
|
||||
prefix,
|
||||
)
|
||||
if norm_layer is None:
|
||||
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.attn = GlmOcrVisionAttention(
|
||||
embed_dim=dim,
|
||||
num_heads=num_heads,
|
||||
projection_size=dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
self.mlp = GlmOcrVisionMLP(
|
||||
dim,
|
||||
mlp_hidden_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
|
||||
|
||||
class GlmOcrVisionPatchEmbed(Glm4vVisionPatchEmbed):
|
||||
pass
|
||||
|
||||
|
||||
class GlmOcrPatchMerger(Glm4vPatchMerger):
|
||||
pass
|
||||
|
||||
|
||||
class GlmOcrVisionTransformer(Glm4vVisionTransformer):
|
||||
def __init__(
|
||||
self,
|
||||
vision_config: GlmOcrVisionConfig,
|
||||
norm_eps: float = 1e-5,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(vision_config, norm_eps, quant_config, prefix)
|
||||
|
||||
del self.post_conv_layernorm
|
||||
del self.embeddings
|
||||
|
||||
patch_size = vision_config.patch_size
|
||||
temporal_patch_size = vision_config.temporal_patch_size
|
||||
in_channels = vision_config.in_channels
|
||||
depth = vision_config.depth
|
||||
self.hidden_size = vision_config.hidden_size
|
||||
self.num_heads = vision_config.num_heads
|
||||
|
||||
self.patch_size = vision_config.patch_size
|
||||
self.spatial_merge_size = vision_config.spatial_merge_size
|
||||
self.out_hidden_size = vision_config.out_hidden_size
|
||||
|
||||
self.patch_embed = Glm4vVisionPatchEmbed(
|
||||
patch_size=patch_size,
|
||||
temporal_patch_size=temporal_patch_size,
|
||||
in_channels=in_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
)
|
||||
|
||||
norm_layer = partial(RMSNorm, eps=norm_eps)
|
||||
head_dim = self.hidden_size // self.num_heads
|
||||
self.rotary_pos_emb = get_rope(
|
||||
head_size=head_dim,
|
||||
max_position=8192,
|
||||
is_neox_style=True,
|
||||
rope_parameters={"partial_rotary_factor": 0.5},
|
||||
)
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
GlmOcrVisionBlock(
|
||||
dim=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
mlp_hidden_dim=vision_config.intermediate_size,
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}",
|
||||
)
|
||||
for layer_idx in range(depth)
|
||||
]
|
||||
)
|
||||
self.merger = GlmOcrPatchMerger(
|
||||
d_model=vision_config.out_hidden_size,
|
||||
context_dim=vision_config.out_hidden_size * vision_config.in_channels,
|
||||
quant_config=quant_config,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.merger",
|
||||
)
|
||||
|
||||
self.downsample = Conv2dLayer(
|
||||
in_channels=vision_config.hidden_size,
|
||||
out_channels=vision_config.out_hidden_size,
|
||||
kernel_size=vision_config.spatial_merge_size,
|
||||
stride=vision_config.spatial_merge_size,
|
||||
)
|
||||
self.post_layernorm = RMSNorm(
|
||||
vision_config.hidden_size, eps=vision_config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.attn_backend = get_vit_attn_backend(
|
||||
head_size=head_dim,
|
||||
dtype=torch.get_default_dtype(),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
grid_thw: torch.Tensor | list[list[int]],
|
||||
) -> torch.Tensor:
|
||||
if isinstance(grid_thw, list):
|
||||
grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
|
||||
|
||||
# patchify
|
||||
x = x.to(device=self.device, dtype=self.dtype)
|
||||
x = self.patch_embed(x)
|
||||
|
||||
# compute position embedding
|
||||
rotary_pos_emb_cos, rotary_pos_emb_sin, image_type_ids = self.rot_pos_emb(
|
||||
grid_thw
|
||||
)
|
||||
# compute cu_seqlens
|
||||
cu_seqlens = torch.repeat_interleave(
|
||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||
).cumsum(dim=0, dtype=torch.int32)
|
||||
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
|
||||
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
|
||||
|
||||
# pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
|
||||
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||
|
||||
# transformers
|
||||
x = x.unsqueeze(1)
|
||||
for blk in self.blocks:
|
||||
x = blk(
|
||||
x,
|
||||
cu_seqlens=cu_seqlens,
|
||||
rotary_pos_emb_cos=rotary_pos_emb_cos,
|
||||
rotary_pos_emb_sin=rotary_pos_emb_sin,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
# adapter
|
||||
x = self.post_layernorm(x)
|
||||
|
||||
x = x.view(-1, self.spatial_merge_size, self.spatial_merge_size, x.shape[-1])
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
x = self.downsample(x).view(-1, self.out_hidden_size)
|
||||
x = self.merger(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
Glm4vMultiModalProcessor,
|
||||
info=Glm4vProcessingInfo,
|
||||
dummy_inputs=Glm4vDummyInputsBuilder,
|
||||
)
|
||||
class GlmOcrForConditionalGeneration(Glm4vForConditionalGeneration):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
with self._mark_tower_model(vllm_config, {"image", "video"}):
|
||||
self.visual = GlmOcrVisionTransformer(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
|
||||
285
vllm/model_executor/models/glm_ocr_mtp.py
Normal file
285
vllm/model_executor/models/glm_ocr_mtp.py
Normal file
@@ -0,0 +1,285 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright 2026 The ZhipuAI Team.
|
||||
# Copyright 2026 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GLM-OCR MTP model compatible with HuggingFace weights."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .glm4 import Glm4DecoderLayer, get_spec_layer_idx_from_weight_name
|
||||
from .glm4_moe_lite_mtp import (
|
||||
Glm4MoeLiteMultiTokenPredictor,
|
||||
SharedHead,
|
||||
)
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (
|
||||
is_pp_missing_parameter,
|
||||
maybe_prefix,
|
||||
)
|
||||
|
||||
|
||||
class GlmOcrMultiTokenPredictorLayer(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
|
||||
config = vllm_config.speculative_config.draft_model_config.hf_config.text_config
|
||||
self.config = config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
|
||||
|
||||
self.device = current_platform.device_type
|
||||
self.shared_head = SharedHead(
|
||||
config=config, prefix=prefix, quant_config=quant_config
|
||||
)
|
||||
self.mtp_block = Glm4DecoderLayer(
|
||||
vllm_config=vllm_config, prefix=prefix, config=self.config
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
spec_step_index: int = 0,
|
||||
) -> torch.Tensor:
|
||||
assert inputs_embeds is not None
|
||||
# masking inputs at position 0, as not needed by MTP
|
||||
inputs_embeds[positions[0] == 0] = 0
|
||||
|
||||
inputs_embeds = self.enorm(inputs_embeds)
|
||||
previous_hidden_states = self.hnorm(previous_hidden_states)
|
||||
|
||||
hidden_states = self.eh_proj(
|
||||
torch.cat([inputs_embeds, previous_hidden_states], dim=-1)
|
||||
)
|
||||
|
||||
hidden_states, residual = self.mtp_block(
|
||||
positions=positions, hidden_states=hidden_states, residual=None
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GlmOcrMultiTokenPredictor(Glm4MoeLiteMultiTokenPredictor):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config.text_config
|
||||
self.mtp_start_layer_idx = config.num_hidden_layers
|
||||
self.num_mtp_layers = config.num_nextn_predict_layers
|
||||
self.layers = torch.nn.ModuleDict(
|
||||
{
|
||||
str(idx): GlmOcrMultiTokenPredictorLayer(
|
||||
vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.layers.{idx}",
|
||||
)
|
||||
for idx in range(
|
||||
self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers,
|
||||
)
|
||||
}
|
||||
)
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
|
||||
class GlmOcrMTP(nn.Module, SupportsPP):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.config = vllm_config.model_config.hf_config.text_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.quant_config = quant_config
|
||||
self.model = GlmOcrMultiTokenPredictor(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
|
||||
self.expert_weights = []
|
||||
self.num_layers = self.config.num_nextn_predict_layers
|
||||
for layer in self.model.layers.values():
|
||||
assert isinstance(layer, GlmOcrMultiTokenPredictorLayer)
|
||||
layer = layer.mtp_block
|
||||
assert isinstance(layer, Glm4DecoderLayer)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, hidden_states, inputs_embeds, spec_step_idx
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor | None:
|
||||
return self.model.compute_logits(hidden_states, spec_step_idx)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if name == "lm_head.weight":
|
||||
spec_layer = self.model.mtp_start_layer_idx
|
||||
name = f"model.layers.{spec_layer}.shared_head.head.weight"
|
||||
elif name == "model.embed_tokens.weight":
|
||||
spec_layer = self.model.mtp_start_layer_idx
|
||||
else:
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
|
||||
if spec_layer is None:
|
||||
continue
|
||||
|
||||
name = self._rewrite_spec_layer_name(spec_layer, name)
|
||||
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
if "scale" in name or "zero_point" in name:
|
||||
# Remapping the name of FP8 kv-scale or zero point.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Some checkpoints include weight scale tensors for the
|
||||
# LM head even when the quantized head isn't built. Skip
|
||||
# them if the model does not expose a matching parameter
|
||||
# to avoid KeyError during load.
|
||||
if name.endswith(".weight_scale") and name not in params_dict:
|
||||
continue
|
||||
|
||||
# According to DeepSeek-V3 Technical Report, MTP modules
|
||||
# shares embedding layer. We only load the first weights.
|
||||
if (
|
||||
spec_layer != self.model.mtp_start_layer_idx
|
||||
and ".layers" not in name
|
||||
):
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
|
||||
"""
|
||||
Rewrite the weight name to match the format of the original model.
|
||||
Add .mtp_block for modules in transformer layer block for spec layer
|
||||
and rename shared layer weights to be top level.
|
||||
"""
|
||||
name = name.replace("model.language_model.layers", "model.layers")
|
||||
|
||||
spec_layer_weight_names = [
|
||||
"embed_tokens",
|
||||
"enorm",
|
||||
"hnorm",
|
||||
"eh_proj",
|
||||
"shared_head",
|
||||
]
|
||||
shared_weight_names = ["embed_tokens"]
|
||||
spec_layer_weight = False
|
||||
shared_weight = False
|
||||
for weight_name in spec_layer_weight_names:
|
||||
if weight_name in name:
|
||||
spec_layer_weight = True
|
||||
if weight_name in shared_weight_names:
|
||||
shared_weight = True
|
||||
break
|
||||
if not spec_layer_weight:
|
||||
# treat rest weights as weights for transformer layer block
|
||||
name = name.replace(
|
||||
f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block."
|
||||
)
|
||||
elif shared_weight:
|
||||
# treat shared weights as top level weights
|
||||
name = name.replace(f"model.layers.{spec_layer}.", "model.")
|
||||
return name
|
||||
@@ -319,8 +319,9 @@ _MULTIMODAL_MODELS = {
|
||||
),
|
||||
"GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
|
||||
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
|
||||
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
|
||||
"Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"), # noqa: E501
|
||||
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
|
||||
"Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),
|
||||
"GlmOcrForConditionalGeneration": ("glm_ocr", "GlmOcrForConditionalGeneration"), # noqa: E501
|
||||
"GraniteSpeechForConditionalGeneration": (
|
||||
"granite_speech",
|
||||
"GraniteSpeechForConditionalGeneration",
|
||||
@@ -472,6 +473,7 @@ _SPECULATIVE_DECODING_MODELS = {
|
||||
"LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
|
||||
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
|
||||
"Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
|
||||
"GlmOcrMTPModel": ("glm_ocr_mtp", "GlmOcrMTP"),
|
||||
"MedusaModel": ("medusa", "Medusa"),
|
||||
"OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
|
||||
"Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
|
||||
|
||||
Reference in New Issue
Block a user