[Model][3/N] Automatic conversion of CrossEncoding model (#20168)

Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi
2025-07-04 20:47:39 +08:00
committed by GitHub
parent 9e5452ee34
commit 2e26f9156a
8 changed files with 234 additions and 133 deletions

View File

@@ -38,15 +38,14 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsCrossEncoding, SupportsLoRA, SupportsPP
from .adapters import as_seq_cls_model
from .interfaces import SupportsLoRA, SupportsPP
from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
@@ -323,114 +322,4 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
return loader.load_weights(weights)
class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA,
SupportsCrossEncoding):
def __init__(
self,
vllm_config: "VllmConfig",
prefix: str = "",
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
pooler_config = vllm_config.model_config.pooler_config
self.vllm_config = vllm_config
self.config = config
self.quant_config = quant_config
self.prefix = prefix
self.model = Qwen3Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.score = RowParallelLinear(config.hidden_size,
config.num_labels,
quant_config=quant_config,
input_is_parallel=False,
bias=False,
prefix=maybe_prefix(prefix, "score"))
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=False,
softmax=True)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model(input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
hidden_states = self._pooler.extract_states(hidden_states,
pooling_metadata)
if isinstance(hidden_states, list):
logits = [self.score(state)[0] for state in hidden_states]
else:
logits, _ = self.score(hidden_states)
pooled_data = self._pooler.head(logits, pooling_metadata)
pooled_outputs = [
self._pooler.build_output(data.squeeze(-1)) for data in pooled_data
]
return PoolerOutput(outputs=pooled_outputs)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
is_original_qwen3_reranker = getattr(self.config,
"is_original_qwen3_reranker",
False)
if not is_original_qwen3_reranker:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
return self.load_weights_from_original_qwen3_reranker(weights)
def load_weights_from_original_qwen3_reranker(
self, weights: Iterable[tuple[str, torch.Tensor]]):
model_config = self.vllm_config.model_config
tokens = getattr(self.config, "classifier_from_token", None)
device = self.score.weight.device
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size,
quant_config=self.quant_config,
prefix=maybe_prefix(
self.prefix, "lm_head"))
loader = AutoWeightsLoader(self)
loaded_weights = loader.load_weights(weights)
from vllm.transformers_utils.tokenizer import get_tokenizer
tokenizer = get_tokenizer(
model_config.tokenizer,
revision=model_config.tokenizer_revision,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code)
a = tokenizer.convert_tokens_to_ids(tokens[0])
b = tokenizer.convert_tokens_to_ids(tokens[1])
weight = self.lm_head.weight.data[b].to(
device) - self.lm_head.weight.data[a].to(device)
self.score.weight.data.copy_(weight)
del self.lm_head
loaded_weights.add("score.weight")
loaded_weights.discard("lm_head.weight")
return loaded_weights
Qwen3ForSequenceClassification = as_seq_cls_model(Qwen3ForCausalLM)