[Model][2/N] Automatic conversion of CrossEncoding model (#19978)

Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi
2025-07-03 21:59:23 +08:00
committed by GitHub
parent 1819fbda63
commit 6f1229f91d
16 changed files with 199 additions and 92 deletions

View File

@@ -93,14 +93,14 @@ ConfigT = TypeVar("ConfigT", bound=ConfigType)
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
"score", "reward", "transcription"]
_ResolvedTask = Literal["generate", "embed", "classify", "score", "reward",
"draft", "transcription"]
_ResolvedTask = Literal["generate", "embed", "classify", "reward", "draft",
"transcription"]
RunnerType = Literal["generate", "pooling", "draft", "transcription"]
_RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = {
"generate": ["generate"],
"pooling": ["embed", "classify", "score", "reward"],
"pooling": ["embed", "classify", "reward"],
"draft": ["draft"],
"transcription": ["transcription"],
}
@@ -777,7 +777,7 @@ class ModelConfig:
if get_pooling_config(model_id, self.revision):
return "embed"
if self.registry.is_cross_encoder_model(architectures):
return "score"
return "classify"
if self.registry.is_transcription_model(architectures):
return "transcription"
@@ -841,14 +841,24 @@ class ModelConfig:
"This model supports multiple tasks: %s. "
"Defaulting to '%s'.", supported_tasks, selected_task)
else:
# Aliases
if task_option == "embedding":
msg = ("The 'embedding' task has been renamed to "
"'embed', please use the new name. The old name "
"will be removed in v1.0.")
warnings.warn(msg, DeprecationWarning, stacklevel=2)
if task_option == "score":
if not runner_support["pooling"]:
msg = (f"This model does not support the '{task_option}' "
f"task. Supported tasks: {supported_tasks}")
raise ValueError(msg)
if self.registry.is_cross_encoder_model(architectures):
task_option = "classify"
else:
task_option = "embed"
else:
# Aliases
if task_option == "embedding":
msg = ("The 'embedding' task has been renamed to "
"'embed', please use the new name. The old name "
"will be removed in v1.0.")
warnings.warn(msg, DeprecationWarning, stacklevel=2)
task_option = "embed"
task_option = "embed"
if task_option not in supported_tasks:
msg = (