[Bugfix] Fix Qwen3-VL-Reranker model loading for sequence classification (#32089)

Signed-off-by: rickychen-infinirc <ricky.chen@infinirc.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
RickyChen / 陳昭儒
2026-01-11 04:40:09 +08:00
committed by GitHub
parent e15a5ff07b
commit 8020a60402
2 changed files with 18 additions and 12 deletions

View File

@@ -401,24 +401,23 @@ def load_weights_using_from_2_way_softmax(
tokens = cast(list[int], tokens)
assert len(tokens) == 2
model.lm_head = ParallelLMHead(
language_model = (
model.get_language_model() if hasattr(model, "get_language_model") else model
)
language_model.lm_head = ParallelLMHead(
text_config.vocab_size, text_config.hidden_size, quant_config=quant_config
)
if text_config.tie_word_embeddings:
# embed_tokens is the assumed name for input embeddings. If the model does not
# have this attribute, we fall back to get_input_embeddings(), which is used by
# the Transformers modeling backend.
text_backbone = (
model.get_language_model().model
if hasattr(model, "get_language_model")
else model.model
)
text_backbone = language_model.model
embed_tokens = (
text_backbone.embed_tokens
if hasattr(text_backbone, "embed_tokens")
else text_backbone.get_input_embeddings()
)
model.lm_head = model.lm_head.tie_weights(embed_tokens)
language_model.lm_head = language_model.lm_head.tie_weights(embed_tokens)
# ModelForPooling is dynamically defined inside the _create_pooling_model_cls
# function, so we need use this hacky method to obtain it.
@@ -438,17 +437,22 @@ def load_weights_using_from_2_way_softmax(
false_id = tokenizer.convert_tokens_to_ids(tokens[0])
true_id = tokenizer.convert_tokens_to_ids(tokens[1])
score_weight = model.lm_head.weight.data[[true_id]].to(
lm_head_weight = language_model.lm_head.weight
score_weight = lm_head_weight.data[[true_id]].to(
torch.float32
) - model.lm_head.weight.data[[false_id]].to(torch.float32)
) - lm_head_weight.data[[false_id]].to(torch.float32)
param = model.score.weight
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, score_weight)
del model.lm_head
del language_model.lm_head
loaded_weights.add("score.weight")
loaded_weights.discard("lm_head.weight")
lm_head_name = "lm_head.weight"
if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
lm_head_name = hf_to_vllm_mapper._map_name(lm_head_name)
loaded_weights.discard(lm_head_name)
return loaded_weights