[Core] Support multiple tasks per model (#20771)
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
vllm/config.py (258 changed lines)
@@ -91,24 +91,19 @@ logger = init_logger(__name__)
 ConfigT = TypeVar("ConfigT", bound=ConfigType)
 
 TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
-                     "score", "reward", "transcription"]
+                     "score", "reward", "transcription", "draft"]
 
-_ResolvedTask = Literal["generate", "embed", "classify", "reward", "draft",
-                        "transcription"]
+_ResolvedTask = Literal["generate", "transcription", "pooling", "embed",
+                        "classify", "reward", "draft"]
 
-RunnerType = Literal["generate", "pooling", "draft", "transcription"]
+RunnerOption = Literal["auto", "generate", "pooling", "draft"]
+
+RunnerType = Literal["generate", "pooling", "draft"]
 
 _RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = {
-    "generate": ["generate"],
-    "pooling": ["embed", "classify", "reward"],
-    "draft": ["draft"],
-    "transcription": ["transcription"],
-}
-
-_TASK_RUNNER: dict[_ResolvedTask, RunnerType] = {
-    task: runner
-    for runner, tasks in _RUNNER_TASKS.items()
-    for task in tasks
+    "generate": ["generate", "transcription"],
+    "pooling": ["pooling", "embed", "classify", "reward"],
+    "draft": [],
 }
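Note: the `_TASK_RUNNER` inverse mapping is dropped here because a task no longer identifies a runner on its own ("transcription" now lives under the "generate" runner, and "pooling" is itself a task). If a task-to-runner lookup is still needed, a minimal standalone sketch can be derived from the new `_RUNNER_TASKS` values above, in the same way the removed dict comprehension did (not part of vLLM after this commit):

from typing import Literal

RunnerType = Literal["generate", "pooling", "draft"]
_ResolvedTask = Literal["generate", "transcription", "pooling", "embed",
                        "classify", "reward", "draft"]

# Values copied from the new _RUNNER_TASKS above; "draft" is empty here,
# so the draft task has no entry in the inverse mapping.
_RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = {
    "generate": ["generate", "transcription"],
    "pooling": ["pooling", "embed", "classify", "reward"],
    "draft": [],
}

# The inverse (task -> runner) that the removed _TASK_RUNNER used to provide.
_TASK_RUNNER: dict[_ResolvedTask, RunnerType] = {
    task: runner
    for runner, tasks in _RUNNER_TASKS.items()
    for task in tasks
}

assert _TASK_RUNNER["transcription"] == "generate"
assert _TASK_RUNNER["classify"] == "pooling"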
@@ -234,11 +229,14 @@ class ModelConfig:
     """Name or path of the Hugging Face model to use. It is also used as the
     content for `model_name` tag in metrics output when `served_model_name` is
     not specified."""
-    task: Literal[TaskOption, Literal["draft"]] = "auto"
-    """The task to use the model for. Each vLLM instance only supports one
-    task, even if the same model can be used for multiple tasks. When the model
-    only supports one task, "auto" can be used to select it; otherwise, you
-    must specify explicitly which task to use."""
+    runner: RunnerOption = "auto"
+    """The type of model runner to use. Each vLLM instance only supports one
+    model runner, even if the same model can be used for multiple types."""
+    task: TaskOption = "auto"
+    """The task to use the model for. If the model supports more than one
+    model runner, this is used to select which model runner to run.
+
+    Note that the model may support other tasks using the same model runner."""
     tokenizer: SkipValidation[str] = None  # type: ignore
     """Name or path of the Hugging Face tokenizer to use. If unspecified, model
     name or path will be used."""
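A hedged usage sketch of the new split: `runner` picks the model-runner type, while `task` disambiguates among the tasks that runner supports. The model name and the assumption that these `ModelConfig` fields are reachable through the usual `LLM` keyword-argument plumbing are illustrative, not taken from this diff:

from vllm import LLM

# task="embed" belongs to _RUNNER_TASKS["pooling"], so the pooling runner
# is selected; runner="auto" (the default) defers to task-based resolution.
llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")
outputs = llm.embed(["vLLM supports multiple tasks per model"])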
@@ -553,10 +551,41 @@
         self.hf_image_processor_config = get_hf_image_processor_config(
             self.model, hf_token=self.hf_token, revision=self.revision)
 
-        supported_tasks, task = self._resolve_task(self.task)
-        self.supported_tasks = supported_tasks
-        self.task = task
-        if self.task in ("draft", "generate"):
+        # For pooling models, self.task is used to indicate the
+        # user-selected task
+        if self.task == "score":
+            if self.registry.is_cross_encoder_model(self.architectures):
+                self.task = "classify"
+            else:
+                self.task = "embed"
+        elif self.task == "embedding":
+            msg = ("The 'embedding' task has been renamed to 'embed', please "
+                   "use the new name. The old name will be removed in v1.0.")
+            warnings.warn(msg, DeprecationWarning, stacklevel=2)
+
+            self.task = "embed"
+
+        all_supported_tasks = self._get_supported_tasks(self.task)
+        logger.debug("Tasks supported by runner type: %s", all_supported_tasks)
+        supported_runner_types = self._get_supported_runner_types(
+            all_supported_tasks)
+        runner_type = self._resolve_runner(self.runner, self.task,
+                                           supported_runner_types,
+                                           all_supported_tasks)
+
+        logger.debug("Selected runner type: %s", runner_type)
+        # For pooling models, self.task is used to indicate the
+        # user-selected task
+        if runner_type == "pooling" and self.task == "auto":
+            selected_task = all_supported_tasks[runner_type][-1]
+            assert selected_task != "pooling"
+            self.task = selected_task
+        self.supported_runner_types = supported_runner_types
+        self.runner_type = runner_type
+        self.supported_tasks = all_supported_tasks[runner_type]
+
+        if self.runner_type in ("draft",
+                                "generate") and self.task != "transcription":
             self.truncation_side = "left"
         else:
             self.truncation_side = "right"
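The alias handling that used to live in `_resolve_task` now runs up front in `__post_init__`: "score" collapses to "classify" or "embed" depending on whether the model is a cross-encoder, and the deprecated "embedding" spelling maps to "embed". A minimal standalone sketch of just that normalization step (hypothetical helper, mirroring the branches above):

import warnings

def normalize_task_alias(task: str, is_cross_encoder: bool) -> str:
    # "score" is an alias: classification for cross-encoders, embedding
    # otherwise (matching the is_cross_encoder_model() branch above).
    if task == "score":
        return "classify" if is_cross_encoder else "embed"
    # "embedding" is the deprecated pre-v1.0 spelling of "embed".
    if task == "embedding":
        warnings.warn("The 'embedding' task has been renamed to 'embed'",
                      DeprecationWarning, stacklevel=2)
        return "embed"
    return task

assert normalize_task_alias("score", is_cross_encoder=True) == "classify"
assert normalize_task_alias("score", is_cross_encoder=False) == "embed"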
@@ -780,11 +809,10 @@
                 f"one of {get_args(TokenizerMode)}.")
         self.tokenizer_mode = tokenizer_mode
 
-    def _get_preferred_task(
+    def _get_preferred_pooling_task(
         self,
         architectures: list[str],
-        supported_tasks: set[_ResolvedTask],
-    ) -> Optional[_ResolvedTask]:
+    ) -> _ResolvedTask:
         model_id = self.model
         if get_pooling_config(model_id, self.revision):
             return "embed"
@@ -795,92 +823,136 @@
 
         suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [
             # Other models follow this pattern
-            ("ForCausalLM", "generate"),
-            ("ForConditionalGeneration", "generate"),
             ("ForSequenceClassification", "classify"),
-            ("ChatModel", "generate"),
-            ("LMHeadModel", "generate"),
             ("EmbeddingModel", "embed"),
             ("RewardModel", "reward"),
         ]
         _, arch = self.registry.inspect_model_cls(architectures)
 
         for suffix, pref_task in suffix_to_preferred_task:
-            if arch.endswith(suffix) and pref_task in supported_tasks:
+            if arch.endswith(suffix):
                 return pref_task
 
-        return None
+        return "embed"
 
-    def _resolve_task(
+    def _get_supported_generation_tasks(
         self,
-        task_option: Literal[TaskOption, Literal["draft"]],
-    ) -> tuple[set[_ResolvedTask], _ResolvedTask]:
-        if task_option == "draft":
-            return {"draft"}, "draft"
-
+        task_option: TaskOption,
+    ) -> list[_ResolvedTask]:
         registry = self.registry
         architectures = self.architectures
 
-        runner_support: dict[RunnerType, bool] = {
-            # NOTE: Listed from highest to lowest priority,
-            # in case the model supports multiple of them
-            "transcription": registry.is_transcription_model(architectures),
-            "generate": registry.is_text_generation_model(architectures),
-            "pooling": registry.is_pooling_model(architectures),
-        }
-        supported_runner_types_lst: list[RunnerType] = [
-            runner_type
-            for runner_type, is_supported in runner_support.items()
-            if is_supported
-        ]
-
-        supported_tasks_lst: list[_ResolvedTask] = [
-            task for runner_type in supported_runner_types_lst
-            for task in _RUNNER_TASKS[runner_type]
-        ]
-        supported_tasks = set(supported_tasks_lst)
+        if registry.is_transcription_only_model(architectures):
+            return ["transcription"]
+
+        supported_tasks = list[_ResolvedTask]()
+        if registry.is_text_generation_model(architectures):
+            supported_tasks.append("generate")
+
+        if registry.is_transcription_model(architectures):
+            supported_tasks.append("transcription")
+
+        return supported_tasks
+
+    def _get_supported_pooling_tasks(
+        self,
+        task_option: TaskOption,
+    ) -> list[_ResolvedTask]:
+        registry = self.registry
+        architectures = self.architectures
+
+        supported_tasks = list[_ResolvedTask]()
+        if registry.is_pooling_model(architectures):
+            supported_tasks.append("pooling")
+
+            # For now, users must specify the task (other than "pooling")
+            # to use for pooling models
+            if task_option == "auto":
+                preferred_task = self._get_preferred_pooling_task(
+                    architectures)
+
+                supported_tasks.append(preferred_task)
+            elif task_option in _RUNNER_TASKS["pooling"]:
+                supported_tasks.append(cast(_ResolvedTask, task_option))
+
+        return supported_tasks
+
+    def _get_supported_tasks(
+        self,
+        task_option: TaskOption,
+    ) -> dict[RunnerType, list[_ResolvedTask]]:
+        return {
+            "generate": self._get_supported_generation_tasks(task_option),
+            "pooling": self._get_supported_pooling_tasks(task_option),
+            "draft": ["draft"]
+        }
+
+    def _get_supported_runner_types(
+        self,
+        supported_tasks: dict[RunnerType, list[_ResolvedTask]],
+    ) -> set[RunnerType]:
+        return {
+            runner
+            for runner, runner_tasks in supported_tasks.items()
+            if len(runner_tasks) > 0
+        }
 
-        if task_option == "auto":
-            selected_task = next(iter(supported_tasks_lst))
+    def _resolve_runner(
+        self,
+        runner_option: RunnerOption,
+        task_option: TaskOption,
+        supported_runner_types: set[RunnerType],
+        supported_tasks: dict[RunnerType, list[_ResolvedTask]],
+    ) -> RunnerType:
+        if not supported_runner_types:
+            raise ValueError("This model does not support any model runners!")
 
-            if len(supported_tasks_lst) > 1:
-                preferred_task = self._get_preferred_task(
-                    architectures, supported_tasks)
-                if preferred_task is not None:
-                    selected_task = preferred_task
+        if runner_option != "auto":
+            if runner_option not in supported_runner_types:
+                raise ValueError(
+                    f"This model does not support runner={runner_option!r}. "
+                    f"Available runners: {supported_runner_types}")
 
-                logger.info(
-                    "This model supports multiple tasks: %s. "
-                    "Defaulting to '%s'.", supported_tasks, selected_task)
-        else:
-            if task_option == "score":
-                if not runner_support["pooling"]:
-                    msg = (f"This model does not support the '{task_option}' "
-                           f"task. Supported tasks: {supported_tasks}")
-                    raise ValueError(msg)
-                if self.registry.is_cross_encoder_model(architectures):
-                    task_option = "classify"
-                else:
-                    task_option = "embed"
+            return runner_option
+
+        if task_option != "auto":
+            for runner, runner_tasks in supported_tasks.items():
+                if task_option in runner_tasks:
+                    return runner
 
-            else:
-                # Aliases
-                if task_option == "embedding":
-                    msg = ("The 'embedding' task has been renamed to "
-                           "'embed', please use the new name. The old name "
-                           "will be removed in v1.0.")
-                    warnings.warn(msg, DeprecationWarning, stacklevel=2)
-
-                    task_option = "embed"
+            task_runner: RunnerType = next(
+                runner for runner, tasks in _RUNNER_TASKS.items()
+                if task_option in tasks)
+            raise ValueError(
+                f"This model does not support task={task_option!r}. "
+                f"Available tasks for runner={task_runner!r}: "
+                f"{supported_tasks[task_runner]}")
 
-            if task_option not in supported_tasks:
-                msg = (
-                    f"This model does not support the '{task_option}' task. "
-                    f"Supported tasks: {supported_tasks}")
-                raise ValueError(msg)
+        suffix_to_preferred_runner: list[tuple[str, RunnerType]] = [
+            ("ForCausalLM", "generate"),
+            ("ForConditionalGeneration", "generate"),
+            ("ChatModel", "generate"),
+            ("LMHeadModel", "generate"),
+            ("ForSequenceClassification", "pooling"),
+            ("EmbeddingModel", "pooling"),
+            ("RewardModel", "pooling"),
+        ]
+        _, arch = self.registry.inspect_model_cls(self.architectures)
 
-            selected_task = task_option
+        for suffix, pref_runner in suffix_to_preferred_runner:
+            if arch.endswith(suffix) and pref_runner in supported_runner_types:
+                return pref_runner
 
-        return supported_tasks, selected_task
+        if "classify" in supported_tasks.get("pooling", []):
+            # When multiple pooling tasks are present, default to
+            # pooling (eg cross-encoder) for non-standard architectures.
+            return "pooling"
+        if "generate" in supported_runner_types:
+            return "generate"
+        if "pooling" in supported_runner_types:
+            return "pooling"
+
+        raise AssertionError("This line should not be reached")
 
     def _parse_quant_hf_config(self):
         quant_cfg = getattr(self.hf_config, "quantization_config", None)
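Taken together, `_resolve_runner` applies a clear precedence: an explicit `runner=` wins, then an explicit `task=` selects the runner that owns it, then the architecture-suffix table, then the generate-before-pooling fallback. A condensed, self-contained sketch of that precedence with the registry lookups replaced by plain arguments (a toy re-implementation, not vLLM's code):

def resolve_runner(runner_option: str,
                   task_option: str,
                   supported_runner_types: set[str],
                   supported_tasks: dict[str, list[str]],
                   arch: str) -> str:
    """Toy mirror of ModelConfig._resolve_runner's precedence order."""
    if not supported_runner_types:
        raise ValueError("This model does not support any model runners!")
    # 1. An explicit runner= is validated and returned as-is.
    if runner_option != "auto":
        if runner_option not in supported_runner_types:
            raise ValueError(f"unsupported runner: {runner_option!r}")
        return runner_option
    # 2. An explicit task= picks the runner whose task list contains it.
    if task_option != "auto":
        for runner, runner_tasks in supported_tasks.items():
            if task_option in runner_tasks:
                return runner
        raise ValueError(f"unsupported task: {task_option!r}")
    # 3. Otherwise the architecture suffix decides (abridged table).
    suffix_to_runner = [("ForCausalLM", "generate"),
                        ("ForSequenceClassification", "pooling")]
    for suffix, pref_runner in suffix_to_runner:
        if arch.endswith(suffix) and pref_runner in supported_runner_types:
            return pref_runner
    # 4. Fallbacks: cross-encoder-style pooling, then generate, then pooling.
    if "classify" in supported_tasks.get("pooling", []):
        return "pooling"
    for fallback in ("generate", "pooling"):
        if fallback in supported_runner_types:
            return fallback
    raise AssertionError("This line should not be reached")

# A model exposing both runners: an explicit task routes to pooling,
# while full "auto" falls through to the architecture suffix.
tasks = {"generate": ["generate"], "pooling": ["pooling", "embed"]}
assert resolve_runner("auto", "embed", {"generate", "pooling"},
                      tasks, "GemmaModel") == "pooling"
assert resolve_runner("auto", "auto", {"generate", "pooling"},
                      tasks, "LlamaForCausalLM") == "generate"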
@@ -1449,14 +1521,6 @@
     def use_mla(self) -> bool:
        return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE
 
-    @property
-    def supported_runner_types(self) -> set[RunnerType]:
-        return {_TASK_RUNNER[task] for task in self.supported_tasks}
-
-    @property
-    def runner_type(self) -> RunnerType:
-        return _TASK_RUNNER[cast(_ResolvedTask, self.task)]
-
     @property
     def is_v1_compatible(self) -> bool:
         architectures = getattr(self.hf_config, "architectures", [])
@@ -2694,7 +2758,7 @@ class SpeculativeConfig:
         if self.model is not None:
             self.draft_model_config = ModelConfig(
                 model=self.model,
-                task="draft",
+                runner="draft",
                 tokenizer=self.target_model_config.tokenizer,
                 tokenizer_mode=self.target_model_config.tokenizer_mode,
                 trust_remote_code=self.target_model_config.
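With the "draft" task removed from the resolution path, speculative decoding opts in through the new `runner` field instead. Because `_get_supported_tasks` always reports `"draft": ["draft"]`, the draft runner is always available, so an explicit `runner="draft"` short-circuits `_resolve_runner` at its first branch. A toy check, with values copied from the hunks above:

# "draft" always has a non-empty task list in _get_supported_tasks, so
# _get_supported_runner_types always includes it and the explicit-runner
# branch of _resolve_runner accepts runner="draft" for any draft model.
supported_tasks = {"generate": ["generate"], "pooling": [], "draft": ["draft"]}
supported_runner_types = {r for r, ts in supported_tasks.items() if ts}
assert "draft" in supported_runner_types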