[New Model]: Support Qwen3 Embedding & Reranker (#19260)
This commit is contained in:
@@ -38,13 +38,15 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .interfaces import SupportsCrossEncoding, SupportsLoRA, SupportsPP
|
||||
from .qwen2 import Qwen2MLP as Qwen3MLP
|
||||
from .qwen2 import Qwen2Model
|
||||
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
|
||||
@@ -319,3 +321,122 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
|
||||
class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA,
|
||||
SupportsCrossEncoding):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.prefix = prefix
|
||||
self.model = Qwen3Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.score = RowParallelLinear(config.hidden_size,
|
||||
config.num_labels,
|
||||
quant_config=quant_config,
|
||||
input_is_parallel=False,
|
||||
bias=False,
|
||||
prefix=maybe_prefix(prefix, "score"))
|
||||
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.LAST,
|
||||
normalize=False,
|
||||
softmax=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.model(input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
hidden_states = self._pooler.extract_states(hidden_states,
|
||||
pooling_metadata)
|
||||
logits, _ = self.score(hidden_states)
|
||||
pooled_data = self._pooler.head(logits, pooling_metadata)
|
||||
pooled_outputs = [
|
||||
self._pooler.build_output(data.squeeze(-1)) for data in pooled_data
|
||||
]
|
||||
return PoolerOutput(outputs=pooled_outputs)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
is_original_qwen3_reranker = getattr(self.config,
|
||||
"is_original_qwen3_reranker",
|
||||
False)
|
||||
|
||||
if not is_original_qwen3_reranker:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
return self.load_weights_from_original_qwen3_reranker(weights)
|
||||
|
||||
def load_weights_from_original_qwen3_reranker(
|
||||
self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
tokens = getattr(self.config, "classifier_from_token", None)
|
||||
assert tokens is not None and len(tokens) == 2, \
|
||||
("Try loading the original Qwen3 Reranker?, see: "
|
||||
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
|
||||
|
||||
self.config.num_labels = 1
|
||||
model_config = self.vllm_config.model_config
|
||||
|
||||
device = self.score.weight.device
|
||||
self.score = RowParallelLinear(self.config.hidden_size,
|
||||
self.config.num_labels,
|
||||
quant_config=self.quant_config,
|
||||
input_is_parallel=False,
|
||||
bias=False,
|
||||
prefix=maybe_prefix(
|
||||
self.prefix, "score")).to(device)
|
||||
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
quant_config=self.quant_config,
|
||||
prefix=maybe_prefix(
|
||||
self.prefix, "lm_head"))
|
||||
|
||||
loader = AutoWeightsLoader(self)
|
||||
loaded_weights = loader.load_weights(weights)
|
||||
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
tokenizer = get_tokenizer(
|
||||
model_config.tokenizer,
|
||||
revision=model_config.tokenizer_revision,
|
||||
tokenizer_mode=model_config.tokenizer_mode,
|
||||
trust_remote_code=model_config.trust_remote_code)
|
||||
|
||||
a = tokenizer.convert_tokens_to_ids(tokens[0])
|
||||
b = tokenizer.convert_tokens_to_ids(tokens[1])
|
||||
weight = self.lm_head.weight.data[b].to(
|
||||
device) - self.lm_head.weight.data[a].to(device)
|
||||
self.score.weight.data.copy_(weight)
|
||||
|
||||
del self.lm_head
|
||||
loaded_weights.add("classifier.weight")
|
||||
loaded_weights.discard("lm_head.weight")
|
||||
|
||||
Reference in New Issue
Block a user