[BugFix] skip language model in Encoder (#30242)
Signed-off-by: dengyunyang <584797741@qq.com>
This commit is contained in:
@@ -520,3 +520,64 @@ def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
method = getattr(text_config, "method", None)
|
||||
assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
|
||||
return SEQ_CLS_LOAD_METHODS[method](model, weights)
|
||||
|
||||
|
||||
def as_mm_encoder_only_model(cls: _T) -> _T:
    """
    Subclass an existing vLLM vl model to support mm encoder only for
    EPD encoder instances.

    The returned subclass avoids constructing the (potentially huge)
    language model: while ``cls.__init__`` runs, the language model
    class's ``__init__`` is temporarily swapped for a stub that
    allocates no parameters, and the resulting placeholder submodule is
    deleted afterwards. ``load_weights`` likewise skips every weight
    under the language-model prefix.

    Args:
        cls: A vLLM multimodal model class. If it has no
            ``embed_multimodal`` attribute (submodel case) it is
            returned unchanged; otherwise it must implement
            ``get_language_model_spec`` returning
            ``(lm_model_cls, lm_attr)``.

    Returns:
        Either ``cls`` itself (submodel case) or a subclass of ``cls``
        with the language model stripped out.

    Raises:
        TypeError: If ``cls`` lacks ``get_language_model_spec``, or if
            that method returns ``None`` for either element.

    NOTE(review): the temporary monkey-patching of
    ``lm_model_cls.__init__`` and ``AutoWeightsLoader.__init__`` is
    process-global and not thread-safe; this assumes model
    construction / weight loading happens single-threaded.
    """
    if not hasattr(cls, "embed_multimodal"):
        # Submodel case: return the original class.
        return cls

    if not hasattr(cls, "get_language_model_spec"):
        raise TypeError(f"{cls} needs to implement `get_language_model_spec` method.")

    lm_model_cls, lm_attr = cls.get_language_model_spec()

    if lm_model_cls is None or lm_attr is None:
        raise TypeError(
            f"{cls}.get_language_model_spec() must return (lm_model_cls, lm_attr)"
        )

    class DummyLM(nn.Module):
        # Stand-in __init__ injected into lm_model_cls while the outer
        # model is being constructed; it builds no parameters. The lone
        # attribute below is the one piece of language-model surface the
        # surrounding construction code still touches.
        def __init__(self, *args, **kwargs):
            self.make_empty_intermediate_tensors = None

    class ModelForMMEncoderOnly(cls):
        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            # Marker so other components can detect the encoder-only variant.
            self.is_mm_encoder_only_model = True
            origin_init = lm_model_cls.__init__
            try:
                # Swap in the no-op __init__ so cls.__init__ creates a
                # parameter-free placeholder instead of the real LM.
                lm_model_cls.__init__ = DummyLM.__init__
                super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

                # Drop the placeholder submodule entirely.
                if hasattr(self, lm_attr):
                    delattr(self, lm_attr)
            finally:
                # Always restore the real __init__, even if construction failed.
                lm_model_cls.__init__ = origin_init

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
            """Load weights via the parent, skipping the LM prefix."""
            from .utils import AutoWeightsLoader

            origin_init_ = AutoWeightsLoader.__init__

            def _new_init_(self, *args, **kwargs):
                origin_init_(self, *args, **kwargs)
                # Ignore all language-model weights (e.g. "language_model.*").
                self.skip_prefixes = (self.skip_prefixes or []) + [f"{lm_attr}."]

            try:
                # Patch so every loader built inside super().load_weights
                # skips the language-model weights; restore afterwards.
                AutoWeightsLoader.__init__ = _new_init_
                result = super().load_weights(weights)
            finally:
                AutoWeightsLoader.__init__ = origin_init_
            return result

    return ModelForMMEncoderOnly  # type: ignore
|
||||
|
||||
Reference in New Issue
Block a user