[bug-fix] GLM OCR Patch Merger context_dim (#37962)

Signed-off-by: JaredforReal <w13431838023@gmail.com>
This commit is contained in:
Jared Wen
2026-03-26 20:11:21 +08:00
committed by GitHub
parent dcdc145893
commit 757eafcf37
2 changed files with 14 additions and 4 deletions

View File

@@ -38,7 +38,10 @@ import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BatchFeature, Glm4vProcessor
from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
from transformers.models.glm4v.configuration_glm4v import (
Glm4vTextConfig,
Glm4vVisionConfig,
)
from transformers.models.glm4v.image_processing_glm4v import (
Glm4vImageProcessor,
smart_resize,
@@ -604,6 +607,7 @@ class Glm4vVisionEmbeddings(nn.Module):
class Glm4vVisionTransformer(nn.Module):
def __init__(
self,
text_config: Glm4vTextConfig,
vision_config: Glm4vVisionConfig,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
@@ -1424,6 +1428,7 @@ class Glm4vForConditionalGeneration(
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Glm4vVisionTransformer(
config.text_config,
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config,

View File

@@ -35,7 +35,10 @@ import torch.nn as nn
from einops import rearrange
if TYPE_CHECKING:
from transformers.models.glm_ocr.configuration_glm_ocr import GlmOcrVisionConfig
from transformers.models.glm_ocr.configuration_glm_ocr import (
GlmOcrTextConfig,
GlmOcrVisionConfig,
)
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
@@ -250,12 +253,13 @@ class GlmOcrPatchMerger(Glm4vPatchMerger):
class GlmOcrVisionTransformer(Glm4vVisionTransformer):
def __init__(
self,
text_config: "GlmOcrTextConfig",
vision_config: "GlmOcrVisionConfig",
norm_eps: float = 1e-5,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__(vision_config, norm_eps, quant_config, prefix)
super().__init__(text_config, vision_config, norm_eps, quant_config, prefix)
del self.post_conv_layernorm
del self.embeddings
@@ -301,7 +305,7 @@ class GlmOcrVisionTransformer(Glm4vVisionTransformer):
)
self.merger = GlmOcrPatchMerger(
d_model=vision_config.out_hidden_size,
context_dim=vision_config.out_hidden_size * vision_config.in_channels,
context_dim=text_config.intermediate_size,
quant_config=quant_config,
bias=False,
prefix=f"{prefix}.merger",
@@ -383,6 +387,7 @@ class GlmOcrForConditionalGeneration(Glm4vForConditionalGeneration):
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = GlmOcrVisionTransformer(
config.text_config,
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config,