[bug-fix] GLM OCR Patch Merger context_dim (#37962)
Signed-off-by: JaredforReal <w13431838023@gmail.com>
@@ -35,7 +35,10 @@ import torch.nn as nn
 from einops import rearrange
 
 if TYPE_CHECKING:
-    from transformers.models.glm_ocr.configuration_glm_ocr import GlmOcrVisionConfig
+    from transformers.models.glm_ocr.configuration_glm_ocr import (
+        GlmOcrTextConfig,
+        GlmOcrVisionConfig,
+    )
 
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
@@ -250,12 +253,13 @@ class GlmOcrPatchMerger(Glm4vPatchMerger):
 class GlmOcrVisionTransformer(Glm4vVisionTransformer):
     def __init__(
         self,
+        text_config: "GlmOcrTextConfig",
         vision_config: "GlmOcrVisionConfig",
         norm_eps: float = 1e-5,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
     ) -> None:
-        super().__init__(vision_config, norm_eps, quant_config, prefix)
+        super().__init__(text_config, vision_config, norm_eps, quant_config, prefix)
 
         del self.post_conv_layernorm
         del self.embeddings
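The hunk above threads the text config through the vision tower's constructor so that a submodule whose dimensions are set by the language model (the patch merger, changed in the next hunk) can be built with the right sizes. A minimal sketch of this config-threading pattern; the class names, field values, and everything except the two config attributes are hypothetical, not vLLM's actual code:

from dataclasses import dataclass

@dataclass
class TextCfg:  # illustrative stand-in for GlmOcrTextConfig
    intermediate_size: int

@dataclass
class VisionCfg:  # illustrative stand-in for GlmOcrVisionConfig
    out_hidden_size: int

class VisionTowerSketch:
    def __init__(self, text_config: TextCfg, vision_config: VisionCfg) -> None:
        # The text config is accepted purely so the merger can be sized
        # by the language model's intermediate width.
        self.merger_in_dim = text_config.intermediate_size
        self.merger_out_dim = vision_config.out_hidden_size

tower = VisionTowerSketch(TextCfg(4096), VisionCfg(1536))
print(tower.merger_in_dim, tower.merger_out_dim)  # 4096 1536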
@@ -301,7 +305,7 @@ class GlmOcrVisionTransformer(Glm4vVisionTransformer):
         )
         self.merger = GlmOcrPatchMerger(
             d_model=vision_config.out_hidden_size,
-            context_dim=vision_config.out_hidden_size * vision_config.in_channels,
+            context_dim=text_config.intermediate_size,
             quant_config=quant_config,
             bias=False,
             prefix=f"{prefix}.merger",
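This hunk is the bug fix itself: context_dim sets the input width of the merger's projection, and per the patch that width is the text model's intermediate_size, not out_hidden_size * in_channels, so the old sizing disagreed with the features the merger actually receives. A minimal sketch of the mismatch, assuming a simplified one-projection merger; PatchMergerSketch and the concrete sizes are illustrative, not vLLM's exact code:

import torch
import torch.nn as nn

class PatchMergerSketch(nn.Module):
    # Illustrative stand-in for GlmOcrPatchMerger: one projection whose
    # in_features is context_dim and whose out_features is d_model.
    def __init__(self, d_model: int, context_dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(context_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

# Hypothetical sizes for illustration only.
out_hidden_size, in_channels, intermediate_size = 1536, 3, 4096

# Per the fix, features reaching the merger are intermediate_size wide.
features = torch.randn(4, intermediate_size)

# Old sizing: 1536 * 3 = 4608 != 4096, so the projection rejects the input.
broken = PatchMergerSketch(out_hidden_size, out_hidden_size * in_channels)
try:
    broken(features)
except RuntimeError as err:
    print("old context_dim fails:", err)

# New sizing matches the incoming feature width.
fixed = PatchMergerSketch(out_hidden_size, intermediate_size)
print("fixed context_dim output:", tuple(fixed(features).shape))  # (4, 1536)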
@@ -383,6 +387,7 @@ class GlmOcrForConditionalGeneration(Glm4vForConditionalGeneration):
 
         with self._mark_tower_model(vllm_config, {"image", "video"}):
             self.visual = GlmOcrVisionTransformer(
+                config.text_config,
                 config.vision_config,
                 norm_eps=getattr(config, "rms_norm_eps", 1e-5),
                 quant_config=quant_config,