[Frontend] Separate pooling APIs in offline inference (#11129)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
vllm/entrypoints/llm.py
@@ -26,7 +26,9 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding.guided_fields import (
     GuidedDecodingRequest, LLMGuidedOptions)
-from vllm.outputs import PoolingRequestOutput, RequestOutput
+from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
+                          PoolingRequestOutput, RequestOutput,
+                          ScoringRequestOutput)
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
@@ -120,7 +122,7 @@ class LLM:
     serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
     """

-    DEPRECATE_LEGACY: ClassVar[bool] = False
+    DEPRECATE_LEGACY: ClassVar[bool] = True
     """A flag to toggle whether to deprecate the legacy generate/encode API."""

     DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
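With ``DEPRECATE_LEGACY`` now set to ``True``, the legacy keyword-style calls (e.g. passing ``prompt_token_ids=``) emit deprecation warnings. Below is a minimal usage sketch of the non-legacy call style; it is not part of the diff, the model name is a placeholder, and ``TokensPrompt`` is assumed to be the intended replacement for the raw token-ID keyword.

    # Sketch only: "facebook/opt-125m" is a placeholder model name.
    from vllm import LLM
    from vllm.inputs import TokensPrompt

    llm = LLM(model="facebook/opt-125m")
    # Wrap pre-tokenized input in a prompt object instead of using the
    # deprecated prompt_token_ids= keyword.
    outputs = llm.generate(TokensPrompt(prompt_token_ids=[1, 2, 3]))
    print(outputs[0].outputs[0].text)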
@@ -257,11 +259,14 @@ class LLM:
         self,
         prompts: Union[PromptType, Sequence[PromptType]],
         /,
-        *,
         sampling_params: Optional[Union[SamplingParams,
                                         Sequence[SamplingParams]]] = None,
+        *,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        guided_options_request: Optional[Union[LLMGuidedOptions,
+                                                GuidedDecodingRequest]] = None,
     ) -> List[RequestOutput]:
         ...
@@ -275,6 +280,9 @@ class LLM:
         prompt_token_ids: Optional[List[int]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        guided_options_request: Optional[Union[LLMGuidedOptions,
+                                                GuidedDecodingRequest]] = None,
     ) -> List[RequestOutput]:
         ...
@@ -288,6 +296,9 @@ class LLM:
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        guided_options_request: Optional[Union[LLMGuidedOptions,
+                                                GuidedDecodingRequest]] = None,
     ) -> List[RequestOutput]:
         ...
@@ -302,6 +313,9 @@ class LLM:
         prompt_token_ids: List[int],
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        guided_options_request: Optional[Union[LLMGuidedOptions,
+                                                GuidedDecodingRequest]] = None,
     ) -> List[RequestOutput]:
         ...
@@ -316,6 +330,9 @@ class LLM:
         prompt_token_ids: List[List[int]],
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        guided_options_request: Optional[Union[LLMGuidedOptions,
+                                                GuidedDecodingRequest]] = None,
     ) -> List[RequestOutput]:
         ...
@@ -328,6 +345,9 @@ class LLM:
         prompt_token_ids: Union[List[int], List[List[int]]],
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        guided_options_request: Optional[Union[LLMGuidedOptions,
+                                                GuidedDecodingRequest]] = None,
     ) -> List[RequestOutput]:
         ...
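In the overloads above, the keyword-only marker now sits after ``sampling_params``, so sampling parameters can be supplied positionally while ``use_tqdm`` and the later arguments stay keyword-only. A hedged sketch of a call against the updated signature, not part of the diff (the model name is a placeholder):

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # placeholder model
    params = SamplingParams(temperature=0.8, max_tokens=32)
    # sampling_params may be passed positionally; use_tqdm must stay a keyword.
    outputs = llm.generate(["Hello, my name is"], params, use_tqdm=False)
    print(outputs[0].outputs[0].text)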
@@ -678,11 +698,12 @@ class LLM:
         self,
         prompts: Union[PromptType, Sequence[PromptType]],
         /,
-        *,
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
+        *,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> List[PoolingRequestOutput]:
         ...
@@ -696,6 +717,7 @@ class LLM:
         prompt_token_ids: Optional[List[int]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> List[PoolingRequestOutput]:
         ...
@@ -709,6 +731,7 @@ class LLM:
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> List[PoolingRequestOutput]:
         ...
@@ -723,6 +746,7 @@ class LLM:
         prompt_token_ids: List[int],
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> List[PoolingRequestOutput]:
         ...
@@ -737,6 +761,7 @@ class LLM:
         prompt_token_ids: List[List[int]],
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> List[PoolingRequestOutput]:
         ...
@@ -749,6 +774,7 @@ class LLM:
         prompt_token_ids: Union[List[int], List[List[int]]],
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> List[PoolingRequestOutput]:
         ...
@@ -768,7 +794,8 @@ class LLM:
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> List[PoolingRequestOutput]:
-        """Generates the completions for the input prompts.
+        """Apply pooling to the hidden states corresponding to the input
+        prompts.

         This class automatically batches the given prompts, considering
         the memory constraint. For the best performance, put all of your prompts
@@ -787,7 +814,7 @@ class LLM:

         Returns:
             A list of ``PoolingRequestOutput`` objects containing the
-            generated embeddings in the same order as the input prompts.
+            pooled hidden states in the same order as the input prompts.

         Note:
             Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
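The generic ``encode()`` call keeps returning ``PoolingRequestOutput`` objects holding the raw pooled hidden states. A hedged sketch of reading them back, not part of the diff; the model name is a placeholder and the ``data`` attribute on the pooled output is an assumption:

    from vllm import LLM

    # Sketch only: placeholder embedding model, mirroring `--task embed`.
    llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed")
    outputs = llm.encode(["Hello, my name is"])  # List[PoolingRequestOutput]
    for output in outputs:
        # Pooled hidden states for this prompt; attribute name assumed.
        print(output.outputs.data)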
@@ -833,28 +860,110 @@ class LLM:
         return self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

+    def embed(
+        self,
+        prompts: Union[PromptType, Sequence[PromptType]],
+        /,
+        *,
+        use_tqdm: bool = True,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+    ) -> List[EmbeddingRequestOutput]:
+        """
+        Generate an embedding vector for each prompt.
+
+        This class automatically batches the given prompts, considering
+        the memory constraint. For the best performance, put all of your prompts
+        into a single list and pass it to this method.
+
+        Args:
+            prompts: The prompts to the LLM. You may pass a sequence of prompts
+                for batch inference. See :class:`~vllm.inputs.PromptType`
+                for more details about the format of each prompt.
+            use_tqdm: Whether to use tqdm to display the progress bar.
+            lora_request: LoRA request to use for generation, if any.
+            prompt_adapter_request: Prompt Adapter request to use for
+                generation, if any.
+
+        Returns:
+            A list of ``EmbeddingRequestOutput`` objects containing the
+            embedding vectors in the same order as the input prompts.
+        """
+        if self.llm_engine.model_config.task != "embed":
+            raise ValueError(
+                "Embedding API is only enabled for `--task embed`")
+
+        items = self.encode(prompts,
+                            use_tqdm=use_tqdm,
+                            lora_request=lora_request,
+                            prompt_adapter_request=prompt_adapter_request)
+
+        return [EmbeddingRequestOutput.from_base(item) for item in items]
+
+    def classify(
+        self,
+        prompts: Union[PromptType, Sequence[PromptType]],
+        /,
+        *,
+        use_tqdm: bool = True,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+    ) -> List[ClassificationRequestOutput]:
+        """
+        Generate class logits for each prompt.
+
+        This class automatically batches the given prompts, considering
+        the memory constraint. For the best performance, put all of your prompts
+        into a single list and pass it to this method.
+
+        Args:
+            prompts: The prompts to the LLM. You may pass a sequence of prompts
+                for batch inference. See :class:`~vllm.inputs.PromptType`
+                for more details about the format of each prompt.
+            use_tqdm: Whether to use tqdm to display the progress bar.
+            lora_request: LoRA request to use for generation, if any.
+            prompt_adapter_request: Prompt Adapter request to use for
+                generation, if any.
+
+        Returns:
+            A list of ``ClassificationRequestOutput`` objects containing the
+            class logits in the same order as the input prompts.
+        """
+        if self.llm_engine.model_config.task != "classify":
+            raise ValueError(
+                "Classification API is only enabled for `--task classify`")
+
+        items = self.encode(prompts,
+                            use_tqdm=use_tqdm,
+                            lora_request=lora_request,
+                            prompt_adapter_request=prompt_adapter_request)
+
+        return [ClassificationRequestOutput.from_base(item) for item in items]
+
     def score(
         self,
         text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
         text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
         /,
         *,
         truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> List[PoolingRequestOutput]:
-        """Generates similarity scores for all pairs <text,text_pair>.
+    ) -> List[ScoringRequestOutput]:
+        """Generate similarity scores for all pairs ``<text,text_pair>``.

-        The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case
-        the text_1 sentence will be replicated N times to pair with the text_2
-        sentences. The input pairs are used to build a list of prompts for the
+        The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
+        In the ``1 - N`` case the ``text_1`` sentence will be replicated ``N``
+        times to pair with the ``text_2`` sentences.
+        The input pairs are used to build a list of prompts for the
         cross encoder model. This class automatically batches the prompts,
         considering the memory constraint. For the best performance, put all
         of your texts into a single list and pass it to this method.

         Args:
             text_1: can be a single prompt or a list of prompts, in which
-                case it has to have the same length as the text_2 list
+                case it has to have the same length as the ``text_2`` list
             text_2: The texts to pair with the query to form the input
                 to the LLM. See :class:`~vllm.inputs.PromptType` for
                 more details about the format of each prompt.
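The new ``embed()`` and ``classify()`` wrappers validate the configured task and convert the generic pooling outputs into task-specific output types. A hedged usage sketch, not part of the diff; the model names are placeholders and the ``probs`` attribute on the classification output is an assumption:

    from vllm import LLM

    # Embedding model, mirroring the `--task embed` check in embed().
    llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed")
    (embed_out,) = llm.embed("Hello, my name is")
    print(len(embed_out.outputs.embedding))  # embedding vector for the prompt

    # Sequence-classification model, mirroring the `--task classify` check.
    clf = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")
    (clf_out,) = clf.classify("Hello, my name is")
    print(clf_out.outputs.probs)  # per-class outputs; attribute name assumed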
@@ -864,7 +973,7 @@ class LLM:
                 generation, if any.

         Returns:
-            A list of ``PoolingRequestOutput`` objects containing the
+            A list of ``ScoringRequestOutput`` objects containing the
             generated scores in the same order as the input prompts.
         """
         runner_type = self.llm_engine.model_config.runner_type
@@ -884,6 +993,8 @@ class LLM:

         if not self.llm_engine.model_config.is_cross_encoder:
             raise ValueError("Your model does not support cross encoding")
+        if self.llm_engine.model_config.task != "score":
+            raise ValueError("Score API is only enabled for `--task score`")

         tokenizer = self.llm_engine.get_tokenizer()
@@ -954,8 +1065,10 @@ class LLM:
         )

         outputs = self._run_engine(use_tqdm=use_tqdm)
-        return self.engine_class.validate_outputs(outputs,
-                                                  PoolingRequestOutput)
+        items = self.engine_class.validate_outputs(outputs,
+                                                   PoolingRequestOutput)
+
+        return [ScoringRequestOutput.from_base(item) for item in items]

     def start_profile(self) -> None:
         self.llm_engine.start_profile()
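With the changes above, ``score()`` now returns ``ScoringRequestOutput`` objects and is only enabled for cross-encoder models loaded with ``--task score``. A hedged usage sketch, not part of the diff; the model name is a placeholder and the ``score`` attribute on the output is an assumption:

    from vllm import LLM

    # Sketch only: placeholder cross-encoder model.
    llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score")
    outputs = llm.score(
        "What is the capital of France?",
        ["The capital of Brazil is Brasilia.",
         "The capital of France is Paris."],
    )  # List[ScoringRequestOutput], one entry per <text_1, text_2> pair
    for output in outputs:
        print(output.outputs.score)  # similarity score; attribute name assumed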