[Model] Add user-configurable task for models that support both generation and embedding (#9424)
This commit is contained in:
@@ -8,7 +8,7 @@ from tqdm import tqdm
|
||||
|
||||
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
|
||||
BeamSearchSequence, get_beam_search_score)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.arg_utils import EngineArgs, TaskOption
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
||||
apply_hf_chat_template,
|
||||
@@ -29,7 +29,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
||||
get_cached_tokenizer)
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import Counter, deprecate_kwargs, is_list_of
|
||||
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -108,6 +108,12 @@ class LLM:
|
||||
DEPRECATE_LEGACY: ClassVar[bool] = False
|
||||
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
|
||||
|
||||
DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
|
||||
"""
|
||||
A flag to toggle whether to deprecate positional arguments in
|
||||
:meth:`LLM.__init__`.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def deprecate_legacy_api(cls):
|
||||
@@ -117,6 +123,13 @@ class LLM:
|
||||
|
||||
cls.DEPRECATE_LEGACY = False
|
||||
|
||||
@deprecate_args(
|
||||
start_index=2, # Ignore self and model
|
||||
is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
|
||||
additional_message=(
|
||||
"All positional arguments other than `model` will be "
|
||||
"replaced with keyword arguments in an upcoming version."),
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
@@ -139,6 +152,8 @@ class LLM:
|
||||
disable_custom_all_reduce: bool = False,
|
||||
disable_async_output_proc: bool = False,
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
# After positional args are removed, move this right below `model`
|
||||
task: TaskOption = "auto",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
'''
|
||||
@@ -153,6 +168,7 @@ class LLM:
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model,
|
||||
task=task,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
skip_tokenizer_init=skip_tokenizer_init,
|
||||
@@ -316,10 +332,21 @@ class LLM:
|
||||
considered legacy and may be deprecated in the future. You should
|
||||
instead pass them via the ``inputs`` parameter.
|
||||
"""
|
||||
if self.llm_engine.model_config.embedding_mode:
|
||||
raise ValueError(
|
||||
task = self.llm_engine.model_config.task
|
||||
if task != "generate":
|
||||
messages = [
|
||||
"LLM.generate() is only supported for (conditional) generation "
|
||||
"models (XForCausalLM, XForConditionalGeneration).")
|
||||
"models (XForCausalLM, XForConditionalGeneration).",
|
||||
]
|
||||
|
||||
supported_tasks = self.llm_engine.model_config.supported_tasks
|
||||
if "generate" in supported_tasks:
|
||||
messages.append(
|
||||
"Your model supports the 'generate' task, but is "
|
||||
f"currently initialized for the '{task}' task. Please "
|
||||
"initialize the model using `--task generate`.")
|
||||
|
||||
raise ValueError(" ".join(messages))
|
||||
|
||||
if prompt_token_ids is not None:
|
||||
parsed_prompts = self._convert_v1_inputs(
|
||||
@@ -692,10 +719,18 @@ class LLM:
|
||||
considered legacy and may be deprecated in the future. You should
|
||||
instead pass them via the ``inputs`` parameter.
|
||||
"""
|
||||
if not self.llm_engine.model_config.embedding_mode:
|
||||
raise ValueError(
|
||||
"LLM.encode() is only supported for embedding models (XModel)."
|
||||
)
|
||||
task = self.llm_engine.model_config.task
|
||||
if task != "embedding":
|
||||
messages = ["LLM.encode() is only supported for embedding models."]
|
||||
|
||||
supported_tasks = self.llm_engine.model_config.supported_tasks
|
||||
if "embedding" in supported_tasks:
|
||||
messages.append(
|
||||
"Your model supports the 'embedding' task, but is "
|
||||
f"currently initialized for the '{task}' task. Please "
|
||||
"initialize the model using `--task embedding`.")
|
||||
|
||||
raise ValueError(" ".join(messages))
|
||||
|
||||
if prompt_token_ids is not None:
|
||||
parsed_prompts = self._convert_v1_inputs(
|
||||
@@ -905,6 +940,3 @@ class LLM:
|
||||
|
||||
def _is_encoder_decoder_model(self):
|
||||
return self.llm_engine.is_encoder_decoder_model()
|
||||
|
||||
def _is_embedding_model(self):
|
||||
return self.llm_engine.is_embedding_model()
|
||||
|
||||
Reference in New Issue
Block a user