[Model] Add user-configurable task for models that support both generation and embedding (#9424)
This commit is contained in:
@@ -3,7 +3,7 @@ import dataclasses
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
|
||||
Tuple, Type, Union, cast)
|
||||
Tuple, Type, Union, cast, get_args)
|
||||
|
||||
import torch
|
||||
|
||||
@@ -12,7 +12,7 @@ from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
|
||||
DeviceConfig, EngineConfig, LoadConfig, LoadFormat,
|
||||
LoRAConfig, ModelConfig, ObservabilityConfig,
|
||||
ParallelConfig, PromptAdapterConfig, SchedulerConfig,
|
||||
SpeculativeConfig, TokenizerPoolConfig)
|
||||
SpeculativeConfig, TaskOption, TokenizerPoolConfig)
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
@@ -84,6 +84,7 @@ class EngineArgs:
|
||||
model: str = 'facebook/opt-125m'
|
||||
served_model_name: Optional[Union[str, List[str]]] = None
|
||||
tokenizer: Optional[str] = None
|
||||
task: TaskOption = "auto"
|
||||
skip_tokenizer_init: bool = False
|
||||
tokenizer_mode: str = 'auto'
|
||||
trust_remote_code: bool = False
|
||||
@@ -198,6 +199,15 @@ class EngineArgs:
|
||||
type=str,
|
||||
default=EngineArgs.model,
|
||||
help='Name or path of the huggingface model to use.')
|
||||
parser.add_argument(
|
||||
'--task',
|
||||
default=EngineArgs.task,
|
||||
choices=get_args(TaskOption),
|
||||
help='The task to use the model for. Each vLLM instance only '
|
||||
'supports one task, even if the same model can be used for '
|
||||
'multiple tasks. When the model only supports one task, "auto" '
|
||||
'can be used to select it; otherwise, you must specify explicitly '
|
||||
'which task to use.')
|
||||
parser.add_argument(
|
||||
'--tokenizer',
|
||||
type=nullable_str,
|
||||
@@ -838,6 +848,7 @@ class EngineArgs:
|
||||
def create_model_config(self) -> ModelConfig:
|
||||
return ModelConfig(
|
||||
model=self.model,
|
||||
task=self.task,
|
||||
# We know this is not None because we set it in __post_init__
|
||||
tokenizer=cast(str, self.tokenizer),
|
||||
tokenizer_mode=self.tokenizer_mode,
|
||||
@@ -1026,13 +1037,13 @@ class EngineArgs:
|
||||
" please file an issue with detailed information.")
|
||||
|
||||
scheduler_config = SchedulerConfig(
|
||||
task=model_config.task,
|
||||
max_num_batched_tokens=self.max_num_batched_tokens,
|
||||
max_num_seqs=self.max_num_seqs,
|
||||
max_model_len=model_config.max_model_len,
|
||||
num_lookahead_slots=num_lookahead_slots,
|
||||
delay_factor=self.scheduler_delay_factor,
|
||||
enable_chunked_prefill=self.enable_chunked_prefill,
|
||||
embedding_mode=model_config.embedding_mode,
|
||||
is_multimodal_model=model_config.is_multimodal_model,
|
||||
preemption_mode=self.preemption_mode,
|
||||
num_scheduler_steps=self.num_scheduler_steps,
|
||||
|
||||
Reference in New Issue
Block a user