[Misc] Split up pooling tasks (#10820)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
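The change replaces fine-grained checks on `model_config.task` with the coarser `model_config.runner_type`: generation models keep a "generate" runner, while the former catch-all "embedding" task is split into several pooling tasks (`embed`, `classify`, `score`, etc.) that all share a "pooling" runner. A minimal sketch of the grouping, with hypothetical names (the real mapping lives inside vLLM's model config, not in this dict):

    # Hypothetical illustration of the task -> runner grouping;
    # these names are illustrative, not vLLM's actual internals.
    RUNNER_BY_TASK = {
        "generate": "generate",
        "embed": "pooling",
        "classify": "pooling",
        "score": "pooling",
    }

    def runner_type_for(task: str) -> str:
        return RUNNER_BY_TASK[task]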
@@ -381,19 +381,20 @@ class LLM:
             considered legacy and may be deprecated in the future. You should
             instead pass them via the ``inputs`` parameter.
         """
-        task = self.llm_engine.model_config.task
-        if task != "generate":
+        runner_type = self.llm_engine.model_config.runner_type
+        if runner_type != "generate":
             messages = [
                 "LLM.generate() is only supported for (conditional) generation "
                 "models (XForCausalLM, XForConditionalGeneration).",
             ]

-            supported_tasks = self.llm_engine.model_config.supported_tasks
-            if "generate" in supported_tasks:
+            supported_runner_types = self.llm_engine.model_config \
+                .supported_runner_types
+            if "generate" in supported_runner_types:
                 messages.append(
-                    "Your model supports the 'generate' task, but is "
-                    f"currently initialized for the '{task}' task. Please "
-                    "initialize the model using `--task generate`.")
+                    "Your model supports the 'generate' runner, but is "
+                    f"currently initialized for the '{runner_type}' runner. "
+                    "Please initialize vLLM using `--task generate`.")

             raise ValueError(" ".join(messages))

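From the user's side, the new check fails fast when the wrong entrypoint is called. A hedged usage sketch (the model name is illustrative, and the `task` keyword is assumed to mirror the `--task` CLI flag):

    from vllm import LLM

    # A model initialized for a pooling task cannot serve generate():
    llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")
    llm.generate("Hello")  # raises ValueError naming the 'generate' runner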
@@ -793,16 +794,18 @@ class LLM:
             considered legacy and may be deprecated in the future. You should
             instead pass them via the ``inputs`` parameter.
         """
-        task = self.llm_engine.model_config.task
-        if task != "embedding":
-            messages = ["LLM.encode() is only supported for embedding models."]
+        runner_type = self.llm_engine.model_config.runner_type
+        if runner_type != "pooling":
+            messages = ["LLM.encode() is only supported for pooling models."]

-            supported_tasks = self.llm_engine.model_config.supported_tasks
-            if "embedding" in supported_tasks:
+            supported_runner_types = self.llm_engine.model_config \
+                .supported_runner_types
+            if "pooling" in supported_runner_types:
                 messages.append(
-                    "Your model supports the 'embedding' task, but is "
-                    f"currently initialized for the '{task}' task. Please "
-                    "initialize the model using `--task embedding`.")
+                    "Your model supports the 'pooling' runner, but is "
+                    f"currently initialized for the '{runner_type}' runner. "
+                    "Please initialize vLLM using `--task embed`, "
+                    "`--task classify`, `--task score` etc.")

             raise ValueError(" ".join(messages))

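The happy path for LLM.encode() now covers any pooling task rather than only embedding. A sketch under the same assumptions as above:

    from vllm import LLM

    llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")
    outputs = llm.encode(["Hello, world!"])
    # each item is a PoolingRequestOutput holding the pooled result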
@@ -864,21 +867,23 @@ class LLM:
             A list of ``PoolingRequestOutput`` objects containing the
             generated scores in the same order as the input prompts.
         """
-        task = self.llm_engine.model_config.task
-        if task != "embedding":
-            messages = ["LLM.score() is only supported for embedding models."]
+        runner_type = self.llm_engine.model_config.runner_type
+        if runner_type != "pooling":
+            messages = ["LLM.score() is only supported for pooling models."]

-            supported_tasks = self.llm_engine.model_config.supported_tasks
-            if "embedding" in supported_tasks:
+            supported_runner_types = self.llm_engine.model_config \
+                .supported_runner_types
+            if "pooling" in supported_runner_types:
                 messages.append(
-                    "Your model supports the 'embedding' task, but is "
-                    f"currently initialized for the '{task}' task. Please "
-                    "initialize the model using `--task embedding`.")
+                    "Your model supports the 'pooling' runner, but is "
+                    f"currently initialized for the '{runner_type}' runner. "
+                    "Please initialize vLLM using `--task embed`, "
+                    "`--task classify`, `--task score` etc.")

             raise ValueError(" ".join(messages))

         if not self.llm_engine.model_config.is_cross_encoder:
-            raise ValueError("Your model does not support the cross encoding")
+            raise ValueError("Your model does not support cross encoding")

         tokenizer = self.llm_engine.get_tokenizer()

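LLM.score() additionally requires a cross-encoder, per the is_cross_encoder check above. A sketch (model name illustrative; the exact score() signature is an assumption, since the hunk only shows its validation logic):

    from vllm import LLM

    llm = LLM(model="cross-encoder/ms-marco-MiniLM-L-6-v2", task="score")
    outputs = llm.score("What is vLLM?",
                        ["vLLM is an inference engine.", "Bananas are yellow."])
    # one PoolingRequestOutput per (query, document) pair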
@@ -573,7 +573,7 @@ def init_app_state(
         enable_auto_tools=args.enable_auto_tool_choice,
         tool_parser=args.tool_call_parser,
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
-    ) if model_config.task == "generate" else None
+    ) if model_config.runner_type == "generate" else None
     state.openai_serving_completion = OpenAIServingCompletion(
         engine_client,
         model_config,
@@ -582,7 +582,7 @@ def init_app_state(
         prompt_adapters=args.prompt_adapters,
         request_logger=request_logger,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
-    ) if model_config.task == "generate" else None
+    ) if model_config.runner_type == "generate" else None
     state.openai_serving_embedding = OpenAIServingEmbedding(
         engine_client,
         model_config,
@@ -590,13 +590,13 @@ def init_app_state(
         request_logger=request_logger,
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
-    ) if model_config.task == "embedding" else None
+    ) if model_config.runner_type == "pooling" else None
     state.openai_serving_scores = OpenAIServingScores(
         engine_client,
         model_config,
         base_model_paths,
         request_logger=request_logger
-    ) if (model_config.task == "embedding" \
+    ) if (model_config.runner_type == "pooling" \
           and model_config.is_cross_encoder) else None
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
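init_app_state wires each OpenAI-compatible handler only when the active runner supports it; a handler left as None presumably causes the corresponding endpoint to reject requests at call time. The idiom, distilled (the factory name is hypothetical):

    # Distilled gating idiom; make_embedding_handler is a hypothetical factory.
    serving_embedding = (make_embedding_handler()
                         if model_config.runner_type == "pooling" else None)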
@@ -224,7 +224,7 @@ async def main(args):
         chat_template=None,
         chat_template_content_format="auto",
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
-    ) if model_config.task == "generate" else None
+    ) if model_config.runner_type == "generate" else None
     openai_serving_embedding = OpenAIServingEmbedding(
         engine,
         model_config,
@@ -232,7 +232,7 @@ async def main(args):
         request_logger=request_logger,
         chat_template=None,
         chat_template_content_format="auto",
-    ) if model_config.task == "embedding" else None
+    ) if model_config.runner_type == "pooling" else None

     tracker = BatchProgressTracker()
     logger.info("Reading batch from %s...", args.input_file)
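The offline batch runner applies the same gating, so a batch file can only target endpoints that the active runner supports. For reference, a hedged example of one request line in the OpenAI-style batch input that run_batch consumes (model name illustrative):

    {"custom_id": "req-1", "method": "POST", "url": "/v1/embeddings",
     "body": {"model": "BAAI/bge-base-en-v1.5", "input": "Hello, world!"}}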