[Frontend] Separate pooling APIs in offline inference (#11129)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2024-12-13 18:40:07 +08:00 (committed by GitHub)
Parent: f93bf2b189
Commit: eeec9e3390
21 changed files with 669 additions and 304 deletions

View File

@@ -26,7 +26,9 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest, LLMGuidedOptions)
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
PoolingRequestOutput, RequestOutput,
ScoringRequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
@@ -120,7 +122,7 @@ class LLM:
serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
"""
DEPRECATE_LEGACY: ClassVar[bool] = False
DEPRECATE_LEGACY: ClassVar[bool] = True
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
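With DEPRECATE_LEGACY flipped to True, the legacy keyword-based call style for generate()/encode() now emits a DeprecationWarning. A minimal sketch of the two styles (model and prompt are illustrative):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # model name illustrative

# Legacy style, now deprecated: prompts passed as a keyword argument.
outputs = llm.generate(prompts=["Hello, my name is"],
                       sampling_params=SamplingParams(max_tokens=8))

# Preferred style: prompts passed positionally.
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=8))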
@@ -257,11 +259,14 @@ class LLM:
self,
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
*,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> List[RequestOutput]:
...
@@ -275,6 +280,9 @@ class LLM:
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> List[RequestOutput]:
...
@@ -288,6 +296,9 @@ class LLM:
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> List[RequestOutput]:
...
@@ -302,6 +313,9 @@ class LLM:
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> List[RequestOutput]:
...
@@ -316,6 +330,9 @@ class LLM:
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> List[RequestOutput]:
...
@@ -328,6 +345,9 @@ class LLM:
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> List[RequestOutput]:
...
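Two changes run through the generate() overloads above: in the primary overload the keyword-only marker * moves after sampling_params, so sampling parameters can now be passed positionally, and the legacy overloads gain prompt_adapter_request and guided_options_request. Continuing the sketch above:

params = SamplingParams(temperature=0.0, max_tokens=16)
outputs = llm.generate("What is the capital of France?", params)  # no keyword needed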
@@ -678,11 +698,12 @@ class LLM:
self,
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
...
@@ -696,6 +717,7 @@ class LLM:
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
...
@@ -709,6 +731,7 @@ class LLM:
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
...
@@ -723,6 +746,7 @@ class LLM:
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
...
@@ -737,6 +761,7 @@ class LLM:
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
...
@@ -749,6 +774,7 @@ class LLM:
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
...
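The encode() overloads get the matching treatment: pooling_params moves ahead of the keyword-only marker and prompt_adapter_request is threaded through. A sketch, assuming an engine built for a pooling task:

from vllm import PoolingParams

outputs = llm.encode("Hello, world")                   # default pooling
outputs = llm.encode("Hello, world", PoolingParams())  # explicit, now positional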
@@ -768,7 +794,8 @@ class LLM:
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
"""Generates the completions for the input prompts.
"""Apply pooling to the hidden states corresponding to the input
prompts.
This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
@@ -787,7 +814,7 @@ class LLM:
Returns:
A list of ``PoolingRequestOutput`` objects containing the
generated embeddings in the same order as the input prompts.
pooled hidden states in the same order as the input prompts.
Note:
Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
@@ -833,28 +860,110 @@ class LLM:
return self.engine_class.validate_outputs(outputs,
PoolingRequestOutput)
def embed(
self,
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[EmbeddingRequestOutput]:
"""
Generate an embedding vector for each prompt.
This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
into a single list and pass it to this method.
Args:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompt.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns:
A list of ``EmbeddingRequestOutput`` objects containing the
embedding vectors in the same order as the input prompts.
"""
if self.llm_engine.model_config.task != "embed":
raise ValueError(
"Embedding API is only enabled for `--task embed`")
items = self.encode(prompts,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
return [EmbeddingRequestOutput.from_base(item) for item in items]
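embed() is a thin task-specific wrapper: it verifies the engine was built with --task embed, delegates to encode(), and narrows each result via EmbeddingRequestOutput.from_base(). A usage sketch (model name illustrative):

from vllm import LLM

llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed")
(output,) = llm.embed("The quick brown fox jumps over the lazy dog")
print(len(output.outputs.embedding))  # dimensionality of the pooled vector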
def classify(
self,
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[ClassificationRequestOutput]:
"""
Generate class logits for each prompt.
This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
into a single list and pass it to this method.
Args:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompt.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns:
A list of ``ClassificationRequestOutput`` objects containing the
class logits in the same order as the input prompts.
"""
if self.llm_engine.model_config.task != "classify":
raise ValueError(
"Classification API is only enabled for `--task classify`")
items = self.encode(prompts,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
return [ClassificationRequestOutput.from_base(item) for item in items]
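classify() follows the same delegate-and-narrow pattern for sequence-classification models. A usage sketch (model name illustrative; assumes ClassificationOutput exposes a probs field):

from vllm import LLM

llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")
(output,) = llm.classify("vLLM is wonderful!")
print(output.outputs.probs)  # one probability per class label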
def score(
self,
text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
/,
*,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[PoolingRequestOutput]:
"""Generates similarity scores for all pairs <text,text_pair>.
) -> List[ScoringRequestOutput]:
"""Generate similarity scores for all pairs ``<text,text_pair>``.
The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case
the text_1 sentence will be replicated N times to pair with the text_2
sentences. The input pairs are used to build a list of prompts for the
The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
In the ``1 -> N`` case the ``text_1`` sentence will be replicated ``N``
times to pair with the ``text_2`` sentences.
The input pairs are used to build a list of prompts for the
cross encoder model. This class automatically batches the prompts,
considering the memory constraint. For the best performance, put all
of your texts into a single list and pass it to this method.
Args:
text_1: can be a single prompt or a list of prompts, in which
case it has to have the same length as the text_2 list
case it has to have the same length as the ``text_2`` list
text_2: The texts to pair with the query to form the input
to the LLM. See :class:`~vllm.inputs.PromptType` for
more details about the format of each prompt.
@@ -864,7 +973,7 @@ class LLM:
generation, if any.
Returns:
A list of ``PoolingRequestOutput`` objects containing the
A list of ``ScoringRequestOutput`` objects containing the
generated scores in the same order as the input prompts.
"""
runner_type = self.llm_engine.model_config.runner_type
@@ -884,6 +993,8 @@ class LLM:
if not self.llm_engine.model_config.is_cross_encoder:
raise ValueError("Your model does not support cross encoding")
if self.llm_engine.model_config.task != "score":
raise ValueError("Score API is only enabled for `--task score`")
tokenizer = self.llm_engine.get_tokenizer()
@@ -954,8 +1065,10 @@ class LLM:
)
outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs,
PoolingRequestOutput)
items = self.engine_class.validate_outputs(outputs,
PoolingRequestOutput)
return [ScoringRequestOutput.from_base(item) for item in items]
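score() additionally checks that the model is a cross encoder running with --task score before tokenizing each <text_1, text_2> pair. A usage sketch (model name illustrative):

from vllm import LLM

llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score")
(output,) = llm.score("What is panda?",
                      "The giant panda is a bear species endemic to China.")
print(output.outputs.score)  # similarity score for the pair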
def start_profile(self) -> None:
self.llm_engine.start_profile()

View File

@@ -900,7 +900,7 @@ class EmbeddingResponse(OpenAIBaseModel):
class ScoreResponseData(OpenAIBaseModel):
index: int
object: str = "score"
score: Union[List[float], str]
score: float
class ScoreResponse(OpenAIBaseModel):
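Narrowing score from Union[List[float], str] to float means each entry in a /score response now carries a single scalar, e.g. (values illustrative):

{"index": 0, "object": "score", "score": 0.9231}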

View File

@@ -18,14 +18,15 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
PoolingRequestOutput)
from vllm.utils import merge_async_iterators
logger = init_logger(__name__)
def _get_embedding(
output: PoolingOutput,
output: EmbeddingOutput,
encoding_format: Literal["float", "base64"],
) -> Union[List[float], str]:
if encoding_format == "float":
@@ -46,8 +47,10 @@ def request_output_to_embedding_response(
data: List[EmbeddingResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
embedding_res = EmbeddingRequestOutput.from_base(final_res)
prompt_token_ids = final_res.prompt_token_ids
embedding = _get_embedding(final_res.outputs, encoding_format)
embedding = _get_embedding(embedding_res.outputs, encoding_format)
embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
data.append(embedding_data)
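Each generic pooled result is now narrowed before serialization, so the float/base64 encoding logic only ever sees an EmbeddingOutput. The pattern, sketched in isolation:

# final_res is a PoolingRequestOutput from the engine
embedding_res = EmbeddingRequestOutput.from_base(final_res)
vector = embedding_res.outputs.embedding  # List[float] to encode as float or base64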

View File

@@ -31,7 +31,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ErrorResponse,
LoadLoraAdapterRequest,
ModelCard, ModelList,
ModelPermission,
ModelPermission, ScoreRequest,
TokenizeChatRequest,
TokenizeCompletionRequest,
UnloadLoraAdapterRequest)
@@ -73,7 +73,7 @@ class LoRAModulePath:
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
EmbeddingCompletionRequest,
EmbeddingCompletionRequest, ScoreRequest,
TokenizeCompletionRequest]
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
@@ -567,12 +567,14 @@ class OpenAIServing:
return None
@staticmethod
def _base_request_id(raw_request: Request,
def _base_request_id(raw_request: Optional[Request],
default: Optional[str] = None) -> Optional[str]:
"""Pulls the request id to use from a header, if provided"""
default = default or random_uuid()
return raw_request.headers.get(
"X-Request-Id", default) if raw_request is not None else default
if raw_request is None:
return default
return raw_request.headers.get("X-Request-Id", default)
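The rewrite makes the no-request path explicit, which matters now that some callers invoke serving code without an HTTP request. Behavior sketch:

rid = OpenAIServing._base_request_id(raw_request)           # "X-Request-Id" header if set, else a random uuid
rid = OpenAIServing._base_request_id(None, "my-request-1")  # no HTTP request: returns the given default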
@staticmethod
def _get_decoded_token(logprob: Logprob,

View File

@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import make_async, merge_async_iterators
@@ -24,13 +24,13 @@ def request_output_to_score_response(
final_res_batch: List[PoolingRequestOutput], request_id: str,
created_time: int, model_name: str) -> ScoreResponse:
data: List[ScoreResponseData] = []
score = None
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
if final_res is not None:
score = final_res.outputs.embedding
score_data = ScoreResponseData(index=idx, score=score)
data.append(score_data)
classify_res = ScoringRequestOutput.from_base(final_res)
score_data = ScoreResponseData(index=idx,
score=classify_res.outputs.score)
data.append(score_data)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,