[Model][3/N] Automatic conversion of CrossEncoding model (#20168)

Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi
2025-07-04 20:47:39 +08:00
committed by GitHub
parent 9e5452ee34
commit 2e26f9156a
8 changed files with 234 additions and 133 deletions

View File

@@ -2,14 +2,17 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, cast
import torch
import torch.nn as nn
from vllm.model_executor.models.config import VerifyAndUpdateConfig
from .interfaces_base import VllmModelForPooling, is_pooling_model
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import PoolingType
_T = TypeVar("_T", bound=type[nn.Module])
@@ -39,7 +42,6 @@ def _create_pooling_model_cls(
default_softmax: bool,
) -> _T:
# Lazy import
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import Pooler, PoolerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata
@@ -162,7 +164,6 @@ def as_seq_cls_model(cls: _T) -> _T:
return cls
# Lazy import
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
@@ -193,6 +194,7 @@ def as_seq_cls_model(cls: _T) -> _T:
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.vllm_config = vllm_config
self.task = vllm_config.model_config.task
self.pooling_type = (
vllm_config.model_config.pooler_config.pooling_type)
@@ -242,6 +244,17 @@ def as_seq_cls_model(cls: _T) -> _T:
]
return PoolerOutput(outputs=pooled_outputs)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
tokens = getattr(self.config, "classifier_from_token", None)
method = getattr(self.config, "method", None)
if tokens is None and method is None:
return super().load_weights(weights)
else:
# Online convert ForCausalLM into
# ForSequenceClassification model.
return seq_cls_model_loader(self, weights)
ModelForSequenceClassification.__name__ = \
_get_pooling_model_name(cls.__name__, "ForSequenceClassification")
@@ -277,3 +290,86 @@ def as_reward_model(cls: _T) -> _T:
_get_pooling_model_name(cls.__name__, "ForReward")
return ModelForReward # type: ignore
class SequenceClassificationConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
config = vllm_config.model_config.hf_config
method = getattr(config, "method", None)
tokens = getattr(config, "classifier_from_token", None)
if method is None:
return
assert tokens is not None
assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
if method == "from_2_way_softmax":
assert len(tokens) == 2
config.num_labels = 1
else:
config.num_labels = len(tokens)
def load_weights_using_from_2_way_softmax(
model, weights: Iterable[tuple[str, torch.Tensor]]):
# refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead)
from vllm.model_executor.models.utils import AutoWeightsLoader
model_config = model.vllm_config.model_config
tokens = getattr(model.config, "classifier_from_token", [])
tokens = cast(list[int], tokens)
assert len(tokens) == 2
device = model.score.weight.device
if model.config.tie_word_embeddings:
model.lm_head = model.model.embed_tokens
else:
model.lm_head = ParallelLMHead(model.config.vocab_size,
model.config.hidden_size,
quant_config=model.quant_config)
loader = AutoWeightsLoader(model)
loaded_weights = loader.load_weights(weights)
from vllm.transformers_utils.tokenizer import get_tokenizer
tokenizer = get_tokenizer(model_config.tokenizer,
revision=model_config.tokenizer_revision,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code)
false_id = tokenizer.convert_tokens_to_ids(tokens[0])
true_id = tokenizer.convert_tokens_to_ids(tokens[1])
weight = model.lm_head.weight.data[true_id].to(device).to(
torch.float32) - model.lm_head.weight.data[false_id].to(device).to(
torch.float32)
model.score.weight.data.copy_(weight)
del model.lm_head
loaded_weights.add("score.weight")
loaded_weights.discard("lm_head.weight")
return loaded_weights
SEQ_CLS_LOAD_METHODS = {
"from_2_way_softmax": load_weights_using_from_2_way_softmax,
}
def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
# Online convert ForCausalLM into ForSequenceClassification model.
# - from_2_way_softmax:
# - Qwen3ForCausalLM
# - Qwen3-Reranker
# - Qwen2ForCausalLM
# - mxbai-rerank-v2
config = model.vllm_config.model_config.hf_config
method = getattr(config, "method", None)
assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
return SEQ_CLS_LOAD_METHODS[method](model, weights)