[V1] Get supported tasks from model runner instead of model config (#21585)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -14,6 +14,7 @@ import torch
|
||||
from prometheus_client import start_http_server
|
||||
from tqdm import tqdm
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
|
||||
from vllm.engine.protocol import EngineClient
|
||||
@@ -335,6 +336,14 @@ async def run_batch(
|
||||
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
supported_tasks = await engine_client \
|
||||
.get_supported_tasks() # type: ignore
|
||||
else:
|
||||
supported_tasks = model_config.supported_tasks
|
||||
|
||||
logger.info("Supported_tasks: %s", supported_tasks)
|
||||
|
||||
# Create the openai serving objects.
|
||||
openai_serving_models = OpenAIServingModels(
|
||||
engine_client=engine_client,
|
||||
@@ -351,7 +360,7 @@ async def run_batch(
|
||||
chat_template=None,
|
||||
chat_template_content_format="auto",
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
) if "generate" in model_config.supported_tasks else None
|
||||
) if "generate" in supported_tasks else None
|
||||
openai_serving_embedding = OpenAIServingEmbedding(
|
||||
engine_client,
|
||||
model_config,
|
||||
@@ -359,19 +368,17 @@ async def run_batch(
|
||||
request_logger=request_logger,
|
||||
chat_template=None,
|
||||
chat_template_content_format="auto",
|
||||
) if "embed" in model_config.supported_tasks else None
|
||||
) if "embed" in supported_tasks else None
|
||||
|
||||
enable_serving_reranking = ("classify" in model_config.supported_tasks
|
||||
and getattr(model_config.hf_config,
|
||||
"num_labels", 0) == 1)
|
||||
enable_serving_reranking = ("classify" in supported_tasks and getattr(
|
||||
model_config.hf_config, "num_labels", 0) == 1)
|
||||
|
||||
openai_serving_scores = ServingScores(
|
||||
engine_client,
|
||||
model_config,
|
||||
openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
) if ("embed" in model_config.supported_tasks
|
||||
or enable_serving_reranking) else None
|
||||
) if ("embed" in supported_tasks or enable_serving_reranking) else None
|
||||
|
||||
tracker = BatchProgressTracker()
|
||||
logger.info("Reading batch from %s...", args.input_file)
|
||||
|
||||
Reference in New Issue
Block a user