[Model] Add user-configurable task for models that support both generation and embedding (#9424)

Author:    Cyrus Leung
Date:      2024-10-19 02:31:58 +08:00
Committer: GitHub
Parent:    7dbe738d65
Commit:    051eaf6db3

33 changed files with 451 additions and 201 deletions
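
The change below (shown for vllm/entrypoints/llm.py) threads a new `task` option ("auto", "generate", or "embedding") from `LLM.__init__` through `EngineArgs`, so checkpoints that support both generation and embedding can be pinned to one mode at load time. A minimal usage sketch; the model name is illustrative, not part of this commit:

    from vllm import LLM

    # "auto" (the default) infers the task from the model architecture;
    # for dual-capability models, pass "generate" or "embedding" explicitly.
    llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embedding")
    outputs = llm.encode(["What is the capital of France?"])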

vllm/entrypoints/llm.py

@@ -8,7 +8,7 @@ from tqdm import tqdm
 
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                               BeamSearchSequence, get_beam_search_score)
-from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.arg_utils import EngineArgs, TaskOption
 from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                          apply_hf_chat_template,
@@ -29,7 +29,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                get_cached_tokenizer)
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import Counter, deprecate_kwargs, is_list_of
+from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
 
 logger = init_logger(__name__)
@@ -108,6 +108,12 @@ class LLM:
     DEPRECATE_LEGACY: ClassVar[bool] = False
     """A flag to toggle whether to deprecate the legacy generate/encode API."""
 
+    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
+    """
+    A flag to toggle whether to deprecate positional arguments in
+    :meth:`LLM.__init__`.
+    """
+
     @classmethod
     @contextmanager
     def deprecate_legacy_api(cls):
@@ -117,6 +123,13 @@
 
         cls.DEPRECATE_LEGACY = False
 
+    @deprecate_args(
+        start_index=2,  # Ignore self and model
+        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
+        additional_message=(
+            "All positional arguments other than `model` will be "
+            "replaced with keyword arguments in an upcoming version."),
+    )
     def __init__(
         self,
         model: str,
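
Per the decorator arguments above, `deprecate_args` flags positional parameters from index 2 onward (everything after `self` and `model`) while `DEPRECATE_INIT_POSARGS` is true. A sketch of the two call styles this implies, assuming `tokenizer` remains the second positional parameter:

    # Will warn once positional-argument deprecation is enabled:
    llm = LLM("facebook/opt-125m", "facebook/opt-125m")

    # Preferred: keyword arguments for everything after `model`.
    llm = LLM("facebook/opt-125m", tokenizer="facebook/opt-125m")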
@@ -139,6 +152,8 @@
         disable_custom_all_reduce: bool = False,
         disable_async_output_proc: bool = False,
         mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+        # After positional args are removed, move this right below `model`
+        task: TaskOption = "auto",
         **kwargs,
     ) -> None:
         '''
@@ -153,6 +168,7 @@
 
         engine_args = EngineArgs(
             model=model,
+            task=task,
             tokenizer=tokenizer,
             tokenizer_mode=tokenizer_mode,
             skip_tokenizer_init=skip_tokenizer_init,
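
`EngineArgs` now receives the task directly, so the same switch should be available wherever engine arguments are built, including the `--task` CLI flag referenced by the new error messages below. A hedged programmatic equivalent of `LLM(..., task=...)`; the model name is again illustrative:

    from vllm.engine.arg_utils import EngineArgs

    # Same effect as LLM(model=..., task="embedding").
    engine_args = EngineArgs(model="intfloat/e5-mistral-7b-instruct",
                             task="embedding")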
@@ -316,10 +332,21 @@
             considered legacy and may be deprecated in the future. You should
             instead pass them via the ``inputs`` parameter.
         """
-        if self.llm_engine.model_config.embedding_mode:
-            raise ValueError(
+        task = self.llm_engine.model_config.task
+        if task != "generate":
+            messages = [
                 "LLM.generate() is only supported for (conditional) generation "
-                "models (XForCausalLM, XForConditionalGeneration).")
+                "models (XForCausalLM, XForConditionalGeneration).",
+            ]
+
+            supported_tasks = self.llm_engine.model_config.supported_tasks
+            if "generate" in supported_tasks:
+                messages.append(
+                    "Your model supports the 'generate' task, but is "
+                    f"currently initialized for the '{task}' task. Please "
+                    "initialize the model using `--task generate`.")
+
+            raise ValueError(" ".join(messages))
 
         if prompt_token_ids is not None:
             parsed_prompts = self._convert_v1_inputs(
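
The upgraded check turns a task mismatch into an actionable error. A sketch of what a caller would now see, assuming a dual-capability model initialized for embedding:

    llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embedding")
    llm.generate("Hello")
    # ValueError: LLM.generate() is only supported for (conditional) generation
    # models (XForCausalLM, XForConditionalGeneration). Your model supports the
    # 'generate' task, but is currently initialized for the 'embedding' task.
    # Please initialize the model using `--task generate`.

(The second sentence appears only when "generate" is in the model's `supported_tasks`.)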
@@ -692,10 +719,18 @@
             considered legacy and may be deprecated in the future. You should
             instead pass them via the ``inputs`` parameter.
         """
-        if not self.llm_engine.model_config.embedding_mode:
-            raise ValueError(
-                "LLM.encode() is only supported for embedding models (XModel)."
-            )
+        task = self.llm_engine.model_config.task
+        if task != "embedding":
+            messages = ["LLM.encode() is only supported for embedding models."]
+
+            supported_tasks = self.llm_engine.model_config.supported_tasks
+            if "embedding" in supported_tasks:
+                messages.append(
+                    "Your model supports the 'embedding' task, but is "
+                    f"currently initialized for the '{task}' task. Please "
+                    "initialize the model using `--task embedding`.")
+
+            raise ValueError(" ".join(messages))
 
         if prompt_token_ids is not None:
             parsed_prompts = self._convert_v1_inputs(
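
`LLM.encode()` gets the mirror-image check. Callers can also inspect the model's capabilities up front rather than catching the error; a minimal sketch using the `supported_tasks` attribute read above (its exact contents are model-dependent):

    supported = llm.llm_engine.model_config.supported_tasks
    if "embedding" in supported:
        print("re-initialize with task='embedding' to use LLM.encode()")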
@@ -905,6 +940,3 @@
 
     def _is_encoder_decoder_model(self):
         return self.llm_engine.is_encoder_decoder_model()
-
-    def _is_embedding_model(self):
-        return self.llm_engine.is_embedding_model()
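
With the `_is_embedding_model` helper removed, callers presumably switch on the task string instead, as `generate()` and `encode()` now do. A one-line sketch of the replacement check:

    # Instead of the removed llm._is_embedding_model():
    is_embedding = llm.llm_engine.model_config.task == "embedding"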