[Model][2/N] Automatic conversion of CrossEncoding model (#19978)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
@@ -93,14 +93,14 @@ ConfigT = TypeVar("ConfigT", bound=ConfigType)
|
||||
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
|
||||
"score", "reward", "transcription"]
|
||||
|
||||
_ResolvedTask = Literal["generate", "embed", "classify", "score", "reward",
|
||||
"draft", "transcription"]
|
||||
_ResolvedTask = Literal["generate", "embed", "classify", "reward", "draft",
|
||||
"transcription"]
|
||||
|
||||
RunnerType = Literal["generate", "pooling", "draft", "transcription"]
|
||||
|
||||
_RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = {
|
||||
"generate": ["generate"],
|
||||
"pooling": ["embed", "classify", "score", "reward"],
|
||||
"pooling": ["embed", "classify", "reward"],
|
||||
"draft": ["draft"],
|
||||
"transcription": ["transcription"],
|
||||
}
|
||||
@@ -777,7 +777,7 @@ class ModelConfig:
|
||||
if get_pooling_config(model_id, self.revision):
|
||||
return "embed"
|
||||
if self.registry.is_cross_encoder_model(architectures):
|
||||
return "score"
|
||||
return "classify"
|
||||
if self.registry.is_transcription_model(architectures):
|
||||
return "transcription"
|
||||
|
||||
@@ -841,14 +841,24 @@ class ModelConfig:
|
||||
"This model supports multiple tasks: %s. "
|
||||
"Defaulting to '%s'.", supported_tasks, selected_task)
|
||||
else:
|
||||
# Aliases
|
||||
if task_option == "embedding":
|
||||
msg = ("The 'embedding' task has been renamed to "
|
||||
"'embed', please use the new name. The old name "
|
||||
"will be removed in v1.0.")
|
||||
warnings.warn(msg, DeprecationWarning, stacklevel=2)
|
||||
if task_option == "score":
|
||||
if not runner_support["pooling"]:
|
||||
msg = (f"This model does not support the '{task_option}' "
|
||||
f"task. Supported tasks: {supported_tasks}")
|
||||
raise ValueError(msg)
|
||||
if self.registry.is_cross_encoder_model(architectures):
|
||||
task_option = "classify"
|
||||
else:
|
||||
task_option = "embed"
|
||||
else:
|
||||
# Aliases
|
||||
if task_option == "embedding":
|
||||
msg = ("The 'embedding' task has been renamed to "
|
||||
"'embed', please use the new name. The old name "
|
||||
"will be removed in v1.0.")
|
||||
warnings.warn(msg, DeprecationWarning, stacklevel=2)
|
||||
|
||||
task_option = "embed"
|
||||
task_option = "embed"
|
||||
|
||||
if task_option not in supported_tasks:
|
||||
msg = (
|
||||
|
||||
Reference in New Issue
Block a user