[Model] Replace embedding models with pooling adapter (#10769)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -370,6 +370,31 @@ class ModelConfig:
        selected_task = next(iter(supported_tasks_lst))

        if len(supported_tasks) > 1:
            suffix_to_preferred_task: List[Tuple[str, _Task]] = [
                # Hardcode the models that are exceptions
                ("AquilaModel", "generate"),
                ("ChatGLMModel", "generate"),
                # Other models follow this pattern
                ("ForCausalLM", "generate"),
                ("ForConditionalGeneration", "generate"),
                ("ChatModel", "generate"),
                ("LMHeadModel", "generate"),
                ("EmbeddingModel", "embedding"),
                ("RewardModel", "embedding"),
                ("ForSequenceClassification", "embedding"),
            ]
            info, arch = ModelRegistry.inspect_model_cls(architectures)

            for suffix, pref_task in suffix_to_preferred_task:
                if arch.endswith(suffix) and pref_task in supported_tasks:
                    selected_task = pref_task
                    break
            else:
                if (arch.endswith("Model")
                        and info.architecture.endswith("ForCausalLM")
                        and "embedding" in supported_tasks):
                    selected_task = "embedding"

            logger.info(
                "This model supports multiple tasks: %s. "
                "Defaulting to '%s'.", supported_tasks, selected_task)
Reference in New Issue
Block a user