[bug-fix] GLM OCR Patch Merger context_dim (#37962)
Signed-off-by: JaredforReal <w13431838023@gmail.com>
This commit is contained in:
@@ -38,7 +38,10 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from transformers import BatchFeature, Glm4vProcessor
|
||||
from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
|
||||
from transformers.models.glm4v.configuration_glm4v import (
|
||||
Glm4vTextConfig,
|
||||
Glm4vVisionConfig,
|
||||
)
|
||||
from transformers.models.glm4v.image_processing_glm4v import (
|
||||
Glm4vImageProcessor,
|
||||
smart_resize,
|
||||
@@ -604,6 +607,7 @@ class Glm4vVisionEmbeddings(nn.Module):
|
||||
class Glm4vVisionTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
text_config: Glm4vTextConfig,
|
||||
vision_config: Glm4vVisionConfig,
|
||||
norm_eps: float = 1e-6,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
@@ -1424,6 +1428,7 @@ class Glm4vForConditionalGeneration(
|
||||
|
||||
with self._mark_tower_model(vllm_config, {"image", "video"}):
|
||||
self.visual = Glm4vVisionTransformer(
|
||||
config.text_config,
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
|
||||
quant_config=quant_config,
|
||||
|
||||
@@ -35,7 +35,10 @@ import torch.nn as nn
|
||||
from einops import rearrange
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.models.glm_ocr.configuration_glm_ocr import GlmOcrVisionConfig
|
||||
from transformers.models.glm_ocr.configuration_glm_ocr import (
|
||||
GlmOcrTextConfig,
|
||||
GlmOcrVisionConfig,
|
||||
)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
|
||||
@@ -250,12 +253,13 @@ class GlmOcrPatchMerger(Glm4vPatchMerger):
|
||||
class GlmOcrVisionTransformer(Glm4vVisionTransformer):
|
||||
def __init__(
|
||||
self,
|
||||
text_config: "GlmOcrTextConfig",
|
||||
vision_config: "GlmOcrVisionConfig",
|
||||
norm_eps: float = 1e-5,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(vision_config, norm_eps, quant_config, prefix)
|
||||
super().__init__(text_config, vision_config, norm_eps, quant_config, prefix)
|
||||
|
||||
del self.post_conv_layernorm
|
||||
del self.embeddings
|
||||
@@ -301,7 +305,7 @@ class GlmOcrVisionTransformer(Glm4vVisionTransformer):
|
||||
)
|
||||
self.merger = GlmOcrPatchMerger(
|
||||
d_model=vision_config.out_hidden_size,
|
||||
context_dim=vision_config.out_hidden_size * vision_config.in_channels,
|
||||
context_dim=text_config.intermediate_size,
|
||||
quant_config=quant_config,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.merger",
|
||||
@@ -383,6 +387,7 @@ class GlmOcrForConditionalGeneration(Glm4vForConditionalGeneration):
|
||||
|
||||
with self._mark_tower_model(vllm_config, {"image", "video"}):
|
||||
self.visual = GlmOcrVisionTransformer(
|
||||
config.text_config,
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
|
||||
quant_config=quant_config,
|
||||
|
||||
Reference in New Issue
Block a user