Enable conversion of multimodal models to pooling tasks (#24451)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
committed by
GitHub
parent
6a50eaa0d3
commit
e090b7b45b
@@ -1,12 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.models.config import VerifyAndUpdateConfig
|
||||
@@ -129,6 +132,41 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
|
||||
return model_name + pooling_suffix
|
||||
|
||||
|
||||
def try_create_mm_pooling_model_cls(orig_cls: _T) -> Optional[_T]:
    """Try to derive a pooling variant of a multimodal model class.

    The source code of ``orig_cls`` is parsed and scanned for a direct call
    to ``init_vllm_registered_model``. Only when such a call is found is a
    subclass created that also inherits from ``VllmModelForPooling`` and
    exposes the pooler of the model returned by ``get_language_model()``.

    Args:
        orig_cls: The model class to convert.

    Returns:
        The derived pooling class, or ``None`` when ``orig_cls`` does not
        call ``init_vllm_registered_model`` by its bare name or when its
        source code cannot be retrieved.
    """
    # Local import: keeps this block self-contained.
    import textwrap

    try:
        source = inspect.getsource(orig_cls)
    except (OSError, TypeError):
        # No retrievable source (e.g. dynamically created or built-in
        # class): the call check below is impossible, so report
        # "not convertible" instead of crashing.
        return None

    # Classes defined inside another scope come back indented; without
    # dedenting, ast.parse() would raise IndentationError.
    tree = ast.parse(textwrap.dedent(source))

    # Only bare-name calls (``init_vllm_registered_model(...)``) count;
    # attribute-style calls (``mod.init_vllm_registered_model(...)``) are
    # deliberately not matched, mirroring the original detection rule.
    calls_registered_model_init = any(
        isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
        and node.func.id == "init_vllm_registered_model"
        for node in ast.walk(tree))
    if not calls_registered_model_init:
        return None

    class ModelForPooling(orig_cls, VllmModelForPooling):

        is_pooling_model = True

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

            # Reuse the pooler owned by the wrapped language model.
            # NOTE(review): presumably the inner language model has itself
            # been built as a pooling model and therefore has a ``pooler``
            # attribute - confirm against init_vllm_registered_model.
            self.pooler = self.get_language_model().pooler

    return ModelForPooling  # type: ignore
|
||||
|
||||
|
||||
def _create_pooling_model_cls(orig_cls: _T) -> _T:
|
||||
# Lazy import
|
||||
from .utils import AutoWeightsLoader, WeightsMapper
|
||||
@@ -399,6 +437,7 @@ def load_weights_using_from_2_way_softmax(
|
||||
from vllm.model_executor.models.utils import AutoWeightsLoader
|
||||
|
||||
model_config = model.vllm_config.model_config
|
||||
|
||||
tokens = getattr(model.config, "classifier_from_token", [])
|
||||
tokens = cast(list[int], tokens)
|
||||
assert len(tokens) == 2
|
||||
@@ -406,9 +445,10 @@ def load_weights_using_from_2_way_softmax(
|
||||
if model.config.tie_word_embeddings:
|
||||
model.lm_head = model.model.embed_tokens
|
||||
else:
|
||||
quant_config = model.vllm_config.quant_config
|
||||
model.lm_head = ParallelLMHead(model.config.vocab_size,
|
||||
model.config.hidden_size,
|
||||
quant_config=model.quant_config)
|
||||
quant_config=quant_config)
|
||||
|
||||
loader = AutoWeightsLoader(model)
|
||||
loaded_weights = loader.load_weights(weights)
|
||||
@@ -452,9 +492,10 @@ def load_weights_no_post_processing(model,
|
||||
if model.config.tie_word_embeddings:
|
||||
model.lm_head = model.model.embed_tokens
|
||||
else:
|
||||
quant_config = model.vllm_config.quant_config
|
||||
model.lm_head = ParallelLMHead(model.config.vocab_size,
|
||||
model.config.hidden_size,
|
||||
quant_config=model.quant_config)
|
||||
quant_config=quant_config)
|
||||
|
||||
loader = AutoWeightsLoader(model)
|
||||
loaded_weights = loader.load_weights(weights)
|
||||
|
||||
Reference in New Issue
Block a user