[Models] Allow converting Qwen3-VL into Reranker model (#31890)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
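
For orientation, a minimal usage sketch of what this change enables, assuming the hf_overrides recipe documented for vLLM's text-only Qwen3 rerankers carries over to Qwen3-VL; the checkpoint name, architecture string, and token pair below are illustrative, not taken from this commit:

    from vllm import LLM

    # Hypothetical conversion of a Qwen3-VL checkpoint into a reranker. The
    # `classifier_from_token` and `method` overrides are exactly the attributes
    # this commit teaches vLLM to read from the top-level hf_config.
    llm = LLM(
        model="Qwen/Qwen3-VL-4B-Instruct",  # illustrative model name
        task="score",
        hf_overrides={
            "architectures": ["Qwen3VLForSequenceClassification"],  # assumed name
            "classifier_from_token": ["no", "yes"],
            "method": "from_2_way_softmax",
        },
    )

    # Score a (query, document) pair; higher means more relevant.
    out = llm.score("what is the capital of France?",
                    "Paris is the capital of France.")
    print(out[0].outputs.score)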
@@ -333,9 +333,14 @@ def as_seq_cls_model(cls: _T) -> _T:
             )
 
         def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-            text_config = self.config.get_text_config()
-            tokens = getattr(text_config, "classifier_from_token", None)
-            method = getattr(text_config, "method", None)
+            hf_config = self.config
+            text_config = hf_config.get_text_config()
+            tokens = getattr(
+                hf_config,
+                "classifier_from_token",
+                getattr(text_config, "classifier_from_token", None),
+            )
+            method = getattr(hf_config, "method", getattr(text_config, "method", None))
 
             def auto_set_score_bias(weights):
                 for name, weight in weights:
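
A note on the lookup order introduced above: for a multimodal checkpoint like Qwen3-VL, user overrides such as classifier_from_token land on the top-level config, while text-only models keep them on the (often identical) text config, so the code checks the top level first and falls back. A runnable toy illustration with stand-in config classes (hypothetical, not transformers types):

    class _TextConfig:                       # stand-in, not a real transformers class
        classifier_from_token = ["no", "yes"]

    class _VLConfig:                         # multimodal: attribute on the nested config
        def get_text_config(self):
            return _TextConfig()

    hf_config = _VLConfig()
    text_config = hf_config.get_text_config()

    # Prefer the top-level config; fall back to the text config.
    tokens = getattr(
        hf_config,
        "classifier_from_token",
        getattr(text_config, "classifier_from_token", None),
    )
    assert tokens == ["no", "yes"]           # resolved via the fallback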
@@ -366,9 +371,14 @@ def as_seq_cls_model(cls: _T) -> _T:
 class SequenceClassificationConfig(VerifyAndUpdateConfig):
     @staticmethod
     def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        text_config = vllm_config.model_config.hf_config.get_text_config()
-        method = getattr(text_config, "method", None)
-        tokens = getattr(text_config, "classifier_from_token", None)
+        hf_config = vllm_config.model_config.hf_config
+        text_config = hf_config.get_text_config()
+        method = getattr(hf_config, "method", getattr(text_config, "method", None))
+        tokens = getattr(
+            hf_config,
+            "classifier_from_token",
+            getattr(text_config, "classifier_from_token", None),
+        )
 
         if method is None:
             return
@@ -378,8 +388,10 @@ class SequenceClassificationConfig(VerifyAndUpdateConfig):
 
         if method == "from_2_way_softmax":
             assert len(tokens) == 2
             hf_config.num_labels = 1
+            text_config.num_labels = 1
         else:
             hf_config.num_labels = len(tokens)
+            text_config.num_labels = len(tokens)
 
         # `llm as reranker` defaults to not using separating token.
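
The point of mirroring num_labels onto the text config: for multimodal models the top-level config and get_text_config() return distinct objects, so updating only one would leave the other stale for any code that consults it. A toy illustration (stand-in classes):

    class _TextConfig:
        num_labels = 2

    class _VLConfig:
        num_labels = 2
        _text = _TextConfig()

        def get_text_config(self):
            return self._text

    hf_config = _VLConfig()
    text_config = hf_config.get_text_config()

    hf_config.num_labels = 1
    text_config.num_labels = 1  # without this, text_config.num_labels stays 2
    assert hf_config.num_labels == text_config.num_labels == 1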
@@ -396,9 +408,14 @@ def load_weights_using_from_2_way_softmax(
 
     model_config = model.vllm_config.model_config
     quant_config = model.vllm_config.quant_config
-    text_config = model.config.get_text_config()
+    hf_config = model.config
+    text_config = hf_config.get_text_config()
 
-    tokens = getattr(text_config, "classifier_from_token", [])
+    tokens = getattr(
+        hf_config,
+        "classifier_from_token",
+        getattr(text_config, "classifier_from_token", []),
+    )
     tokens = cast(list[int], tokens)
     assert len(tokens) == 2
 
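
For readers unfamiliar with this loader: from_2_way_softmax collapses the two classifier-token rows of the (tied) lm_head into a single scoring vector, since a softmax over two logits equals a sigmoid of their difference. A self-contained numeric check of that identity (shapes and token ids made up):

    import torch

    hidden, vocab = 4, 8
    lm_head = torch.randn(vocab, hidden)  # rows are per-token logit vectors
    false_id, true_id = 2, 5              # stand-ins for tokens like ["no", "yes"]

    # One row suffices: score(x) = (w_true - w_false) @ x
    score_w = lm_head[true_id] - lm_head[false_id]

    x = torch.randn(hidden)
    logits = torch.stack([lm_head[false_id] @ x, lm_head[true_id] @ x])
    p_yes = torch.softmax(logits, dim=0)[1]
    assert torch.allclose(p_yes, torch.sigmoid(score_w @ x), atol=1e-6)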
@@ -409,10 +426,15 @@ def load_weights_using_from_2_way_softmax(
     # embed_tokens is the assumed name for input embeddings. If the model does not
     # have this attribute, we fall back to get_input_embeddings(), which is used by
     # the Transformers modeling backend.
+    text_backbone = (
+        model.get_language_model().model
+        if hasattr(model, "get_language_model")
+        else model.model
+    )
     embed_tokens = (
-        model.model.embed_tokens
-        if hasattr(model.model, "embed_tokens")
-        else model.model.get_input_embeddings()
+        text_backbone.embed_tokens
+        if hasattr(text_backbone, "embed_tokens")
+        else text_backbone.get_input_embeddings()
     )
     model.lm_head = model.lm_head.tie_weights(embed_tokens)
 
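
The new text_backbone indirection is the heart of the Qwen3-VL support here: plain decoders expose the text stack directly at model.model, while multimodal wrappers put it behind get_language_model(). A runnable toy demonstration of the resolution on stand-in classes (all names below are illustrative):

    class _Decoder:
        embed_tokens = "embedding table"     # stand-in for an nn.Embedding

    class _TextModel:                        # plain decoder-only model
        model = _Decoder()

    class _LanguageModel:                    # e.g. a Qwen3ForCausalLM-like wrapper
        model = _Decoder()

    class _VLModel:                          # multimodal wrapper around the LM
        def get_language_model(self):
            return _LanguageModel()

    def resolve_embed_tokens(model):
        # Same resolution the diff performs, on the stand-in classes above.
        text_backbone = (
            model.get_language_model().model
            if hasattr(model, "get_language_model")
            else model.model
        )
        return (
            text_backbone.embed_tokens
            if hasattr(text_backbone, "embed_tokens")
            else text_backbone.get_input_embeddings()
        )

    assert resolve_embed_tokens(_TextModel()) == "embedding table"
    assert resolve_embed_tokens(_VLModel()) == "embedding table"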
@@ -516,8 +538,9 @@ def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
     # - GemmaForCausalLM
     #   - bge-reranker-v2-gemma
 
-    text_config = model.vllm_config.model_config.hf_config.get_text_config()
-    method = getattr(text_config, "method", None)
+    hf_config = model.vllm_config.model_config.hf_config
+    text_config = hf_config.get_text_config()
+    method = getattr(hf_config, "method", getattr(text_config, "method", None))
     assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
     return SEQ_CLS_LOAD_METHODS[method](model, weights)
 
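
Finally, for context on the dispatch at the end: SEQ_CLS_LOAD_METHODS is a registry mapping method names to loader functions, so the resolved method string selects the weight-conversion strategy. A stripped-down sketch of the pattern; only from_2_way_softmax is visible in this diff, and any other entries are assumptions:

    # Registry pattern used by seq_cls_model_loader (sketch, not the full table).
    SEQ_CLS_LOAD_METHODS = {
        "from_2_way_softmax": load_weights_using_from_2_way_softmax,
        # ... other strategies registered under their `method` config value
    }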