[Model] Re-add the implicit conversion feature for as_seq_cls_model (#21103)

Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi
2025-07-18 15:15:07 +08:00
committed by GitHub
parent ba2dfbb0c2
commit ca4eb82bcb
11 changed files with 165 additions and 75 deletions

View File

@@ -551,7 +551,7 @@ class ModelConfig:
# For pooling models, self.task is used to indicate the
# user-selected task
if self.task == "score":
if self.registry.is_cross_encoder_model(self.architectures):
if self._is_classify_task(self.architectures):
self.task = "classify"
else:
self.task = "embed"
@@ -806,6 +806,12 @@ class ModelConfig:
f"one of {get_args(TokenizerMode)}.")
self.tokenizer_mode = tokenizer_mode
def _is_classify_task(self, architectures: list[str]):
for arch in architectures:
if arch.endswith("ForSequenceClassification"):
return True
return self.registry.is_cross_encoder_model(architectures)
def _get_preferred_pooling_task(
self,
architectures: list[str],
@@ -813,14 +819,11 @@ class ModelConfig:
model_id = self.model
if get_pooling_config(model_id, self.revision):
return "embed"
if self.registry.is_cross_encoder_model(architectures):
return "classify"
if self.registry.is_transcription_model(architectures):
return "transcription"
suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [
# Other models follow this pattern
("ForSequenceClassification", "classify"),
("EmbeddingModel", "embed"),
("RewardModel", "reward"),
]
@@ -878,11 +881,14 @@ class ModelConfig:
self,
task_option: TaskOption,
) -> dict[RunnerType, list[_ResolvedTask]]:
return {
"generate": self._get_supported_generation_tasks(task_option),
"pooling": self._get_supported_pooling_tasks(task_option),
"draft": ["draft"]
}
if self._is_classify_task(self.architectures):
return {"generate": [], "pooling": ["classify"], "draft": []}
else:
return {
"generate": self._get_supported_generation_tasks(task_option),
"pooling": self._get_supported_pooling_tasks(task_option),
"draft": ["draft"]
}
def _get_supported_runner_types(
self,
@@ -925,12 +931,16 @@ class ModelConfig:
f"Available tasks for runner={task_runner!r}: "
f"{supported_tasks[task_runner]}")
if "classify" in supported_tasks.get("pooling", []):
# When multiple pooling tasks are present, default to
# pooling (eg cross-encoder) for non-standard architectures.
return "pooling"
suffix_to_preferred_runner: list[tuple[str, RunnerType]] = [
("ForCausalLM", "generate"),
("ForConditionalGeneration", "generate"),
("ChatModel", "generate"),
("LMHeadModel", "generate"),
("ForSequenceClassification", "pooling"),
("EmbeddingModel", "pooling"),
("RewardModel", "pooling"),
]
@@ -940,10 +950,6 @@ class ModelConfig:
if arch.endswith(suffix) and pref_runner in supported_runner_types:
return pref_runner
if "classify" in supported_tasks.get("pooling", []):
# When multiple pooling tasks are present, default to
# pooling (eg cross-encoder) for non-standard architectures.
return "pooling"
if "generate" in supported_runner_types:
return "generate"
if "pooling" in supported_runner_types:
@@ -1525,7 +1531,7 @@ class ModelConfig:
@property
def is_matryoshka(self) -> bool:
return (hasattr(self.hf_config, "matryoshka_dimensions")
return (bool(getattr(self.hf_config, "matryoshka_dimensions", None))
or getattr(self.hf_config, "is_matryoshka", False))
@property
@@ -1539,13 +1545,11 @@ class ModelConfig:
return getattr(self.hf_config, "use_pad_token", True)
def get_and_verify_max_len(self, max_model_len: int):
# For pooling models, the tokenizer's `model_max_length` is often a
# reliable source for the maximum sequence length. However, for
# generative models, this can be incorrect and unduly limit the
# context window (e.g., DeepSeek-R1). Therefore, we only consider
# tokenizer_config for pooling models.
# Consider max_model_len in tokenizer_config only when
# pooling models use absolute position_embedding.
tokenizer_config = None
if self.runner_type == "pooling":
if (self.runner_type == "pooling" and getattr(
self.hf_config, "position_embedding_type", "") == "absolute"):
tokenizer_config = try_get_tokenizer_config(
self.tokenizer,
trust_remote_code=self.trust_remote_code,