[BugFix] skip language model in Encoder (#30242)

Signed-off-by: dengyunyang <584797741@qq.com>
This commit is contained in:
dengyunyang
2025-12-22 21:25:59 +08:00
committed by GitHub
parent 2cf91c2ea4
commit 8f8f469b1b
8 changed files with 116 additions and 3 deletions

View File

@@ -520,3 +520,64 @@ def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
method = getattr(text_config, "method", None)
assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
return SEQ_CLS_LOAD_METHODS[method](model, weights)
def as_mm_encoder_only_model(cls: _T) -> _T:
    """Subclass an existing vLLM VL model to support mm-encoder-only mode
    for EPD encoder instances.

    The language model is never materialized: while the parent constructor
    runs, the LM class's ``__init__`` is temporarily replaced with a no-op,
    and the resulting placeholder submodule is deleted afterwards.
    ``load_weights`` likewise skips every weight under the LM attribute's
    prefix.

    Args:
        cls: The model class to wrap. Submodel classes (those without
            ``embed_multimodal``) are returned unchanged.

    Returns:
        A subclass of ``cls`` with the language model stripped, or ``cls``
        itself in the submodel case.

    Raises:
        TypeError: If ``cls`` lacks ``get_language_model_spec`` or that
            method does not return a usable ``(lm_model_cls, lm_attr)``.
    """
    if not hasattr(cls, "embed_multimodal"):
        # Submodel case: return the original class.
        return cls

    if not hasattr(cls, "get_language_model_spec"):
        raise TypeError(f"{cls} need to implement `get_language_model_spec` method.")

    lm_model_cls, lm_attr = cls.get_language_model_spec()
    if lm_model_cls is None or lm_attr is None:
        raise TypeError(
            f"{cls}.get_language_model_spec() must return (lm_model_cls, lm_attr)"
        )

    def _noop_lm_init(self, *args, **kwargs):
        # Stand-in for lm_model_cls.__init__: skips building the LM weights
        # entirely. It deliberately does NOT call nn.Module.__init__ — the
        # placeholder instance is deleted right after construction (see the
        # delattr below), so only this one plain attribute is ever touched.
        self.make_empty_intermediate_tensors = None

    class ModelForMMEncoderOnly(cls):
        # Class attribute (not an instance attribute set in __init__) so that
        # `supports_mm_encoder_only`, whose argument may be either the class
        # or an instance, detects the flag in both cases.
        is_mm_encoder_only_model = True

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            origin_init = lm_model_cls.__init__
            try:
                # Temporarily neuter the LM constructor so the parent
                # __init__ builds everything except the language model.
                lm_model_cls.__init__ = _noop_lm_init
                super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
                if hasattr(self, lm_attr):
                    delattr(self, lm_attr)
            finally:
                # Always restore the original constructor — even if the
                # parent __init__ raises — so regular (non-encoder-only)
                # instantiations of lm_model_cls remain unaffected.
                lm_model_cls.__init__ = origin_init

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
            """Load weights, skipping everything under the LM prefix."""
            from .utils import AutoWeightsLoader

            origin_init_ = AutoWeightsLoader.__init__

            def _new_init_(loader_self, *args, **kwargs):
                origin_init_(loader_self, *args, **kwargs)
                # Append the LM prefix so its weights are ignored by the
                # parent's loader without touching the parent's code.
                loader_self.skip_prefixes = (loader_self.skip_prefixes or []) + [
                    f"{lm_attr}."
                ]

            try:
                AutoWeightsLoader.__init__ = _new_init_
                result = super().load_weights(weights)
            finally:
                AutoWeightsLoader.__init__ = origin_init_
            return result

    return ModelForMMEncoderOnly  # type: ignore

View File

@@ -141,6 +141,14 @@ class SupportsMultiModal(Protocol):
"""
...
@classmethod
def get_language_model_spec(cls) -> tuple[type[nn.Module] | None, str | None]:
    """Return the language model spec:
    (language model class, language model attr).

    Returns:
        A ``(lm_class, lm_attr)`` pair: the language-model *class* (not an
        instance) and the attribute name the model stores it under. The
        default ``(None, None)`` signals that the model does not provide a
        spec for mm-encoder-only construction.
    """
    return None, None
@overload
def embed_input_ids(self, input_ids: Tensor) -> Tensor: ...
@@ -302,6 +310,10 @@ def supports_multimodal_encoder_tp_data(model: type[object] | object) -> bool:
return getattr(model, "supports_encoder_tp_data", False)
def supports_mm_encoder_only(model: type[object] | object) -> bool:
return getattr(model, "is_mm_encoder_only_model", False)
@overload
def supports_multimodal_pruning(
model: type[object],

View File

@@ -34,7 +34,7 @@ import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BatchFeature
from transformers import BatchFeature, Qwen2ForCausalLM
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig,
@@ -1567,3 +1567,11 @@ class Qwen2_5_VLForConditionalGeneration(
connector="visual.merger.",
tower_model="visual.",
)
@classmethod
def get_language_model_spec(cls) -> tuple[type[nn.Module] | None, str | None]:
    """Return the language model spec:
    (language model class, language model attr).

    Returns:
        ``(Qwen2ForCausalLM, "language_model")`` — the language-model
        *class* and the attribute name it is stored under, consumed by
        ``as_mm_encoder_only_model`` to skip LM construction/weights.
    """
    return Qwen2ForCausalLM, "language_model"

View File

@@ -2090,3 +2090,11 @@ class Qwen3VLForConditionalGeneration(
connector="visual.merger",
tower_model="visual.",
)
@classmethod
def get_language_model_spec(cls) -> tuple[type[nn.Module] | None, str | None]:
    """Return the language model spec:
    (language model class, language model attr).

    Returns:
        ``(Qwen3LLMForCausalLM, "language_model")`` — the language-model
        *class* and the attribute name it is stored under, consumed by
        ``as_mm_encoder_only_model`` to skip LM construction/weights.
    """
    return Qwen3LLMForCausalLM, "language_model"