Truncation control for embedding models (#14776)
Signed-off-by: Gabriel Marinho <gmarinho@ibm.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
@@ -25,6 +25,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
||||
resolve_chat_template_content_format)
|
||||
from vllm.entrypoints.score_utils import (_cosine_similarity,
|
||||
_validate_score_input_lens)
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
|
||||
from vllm.logger import init_logger
|
||||
@@ -793,6 +794,7 @@ class LLM:
|
||||
pooling_params: Optional[Union[PoolingParams,
|
||||
Sequence[PoolingParams]]] = None,
|
||||
*,
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
@@ -807,6 +809,7 @@ class LLM:
|
||||
pooling_params: Optional[Union[PoolingParams,
|
||||
Sequence[PoolingParams]]] = None,
|
||||
prompt_token_ids: Optional[list[int]] = None,
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
@@ -821,6 +824,7 @@ class LLM:
|
||||
pooling_params: Optional[Union[PoolingParams,
|
||||
Sequence[PoolingParams]]] = None,
|
||||
prompt_token_ids: Optional[list[list[int]]] = None,
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
@@ -836,6 +840,7 @@ class LLM:
|
||||
Sequence[PoolingParams]]] = None,
|
||||
*,
|
||||
prompt_token_ids: list[int],
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
@@ -851,6 +856,7 @@ class LLM:
|
||||
Sequence[PoolingParams]]] = None,
|
||||
*,
|
||||
prompt_token_ids: list[list[int]],
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
@@ -864,6 +870,7 @@ class LLM:
|
||||
prompts: None,
|
||||
pooling_params: None,
|
||||
prompt_token_ids: Union[list[int], list[list[int]]],
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
@@ -882,6 +889,7 @@ class LLM:
|
||||
pooling_params: Optional[Union[PoolingParams,
|
||||
Sequence[PoolingParams]]] = None,
|
||||
prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
@@ -946,10 +954,15 @@ class LLM:
|
||||
for pooling_param in pooling_params:
|
||||
pooling_param.verify(self.llm_engine.model_config)
|
||||
|
||||
tokenization_kwargs: dict[str, Any] = {}
|
||||
_validate_truncation_size(self.llm_engine.model_config.max_model_len,
|
||||
truncate_prompt_tokens, tokenization_kwargs)
|
||||
|
||||
self._validate_and_add_requests(
|
||||
prompts=parsed_prompts,
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
|
||||
@@ -962,6 +975,7 @@ class LLM:
|
||||
prompts: Union[PromptType, Sequence[PromptType]],
|
||||
/,
|
||||
*,
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
use_tqdm: bool = True,
|
||||
pooling_params: Optional[Union[PoolingParams,
|
||||
Sequence[PoolingParams]]] = None,
|
||||
@@ -995,6 +1009,7 @@ class LLM:
|
||||
"Embedding API is only enabled for `--task embed`")
|
||||
|
||||
items = self.encode(prompts,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
use_tqdm=use_tqdm,
|
||||
pooling_params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
@@ -1055,6 +1070,7 @@ class LLM:
|
||||
|
||||
encoded_output: list[PoolingRequestOutput] = self.encode(
|
||||
text_1 + text_2,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
@@ -1098,9 +1114,8 @@ class LLM:
|
||||
pooling_params = PoolingParams()
|
||||
|
||||
tokenization_kwargs: dict[str, Any] = {}
|
||||
if truncate_prompt_tokens is not None:
|
||||
tokenization_kwargs["truncation"] = True
|
||||
tokenization_kwargs["max_length"] = truncate_prompt_tokens
|
||||
_validate_truncation_size(self.llm_engine.model_config.max_model_len,
|
||||
truncate_prompt_tokens, tokenization_kwargs)
|
||||
|
||||
parsed_prompts = []
|
||||
|
||||
@@ -1323,6 +1338,7 @@ class LLM:
|
||||
Sequence[PoolingParams]],
|
||||
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
guided_options: Optional[GuidedDecodingRequest] = None,
|
||||
priority: Optional[list[int]] = None,
|
||||
) -> None:
|
||||
@@ -1359,6 +1375,7 @@ class LLM:
|
||||
self._add_request(
|
||||
prompt,
|
||||
params[i] if isinstance(params, Sequence) else params,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request[i] if isinstance(
|
||||
lora_request, Sequence) else lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
@@ -1369,6 +1386,7 @@ class LLM:
|
||||
self,
|
||||
prompt: PromptType,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
@@ -1379,6 +1397,7 @@ class LLM:
|
||||
prompt,
|
||||
params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
@@ -1014,7 +1014,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
|
||||
encoding_format: Literal["float", "base64"] = "float"
|
||||
dimensions: Optional[int] = None
|
||||
user: Optional[str] = None
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
|
||||
|
||||
# doc: begin-embedding-pooling-params
|
||||
additional_data: Optional[Any] = None
|
||||
@@ -1049,7 +1049,7 @@ class EmbeddingChatRequest(OpenAIBaseModel):
|
||||
encoding_format: Literal["float", "base64"] = "float"
|
||||
dimensions: Optional[int] = None
|
||||
user: Optional[str] = None
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
|
||||
|
||||
# doc: begin-chat-embedding-pooling-params
|
||||
additional_data: Optional[Any] = None
|
||||
@@ -1116,7 +1116,7 @@ class ScoreRequest(OpenAIBaseModel):
|
||||
model: Optional[str] = None
|
||||
text_1: Union[list[str], str]
|
||||
text_2: Union[list[str], str]
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
|
||||
|
||||
# doc: begin-score-pooling-params
|
||||
additional_data: Optional[Any] = None
|
||||
@@ -1142,7 +1142,7 @@ class RerankRequest(OpenAIBaseModel):
|
||||
query: str
|
||||
documents: list[str]
|
||||
top_n: int = Field(default_factory=lambda: 0)
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
|
||||
|
||||
# doc: begin-rerank-pooling-params
|
||||
additional_data: Optional[Any] = None
|
||||
|
||||
@@ -21,6 +21,7 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
|
||||
ErrorResponse, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
|
||||
PoolingRequestOutput)
|
||||
@@ -85,16 +86,7 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
request_id = f"embd-{self._base_request_id(raw_request)}"
|
||||
created_time = int(time.time())
|
||||
|
||||
truncate_prompt_tokens = None
|
||||
|
||||
if request.truncate_prompt_tokens is not None:
|
||||
if request.truncate_prompt_tokens <= self.max_model_len:
|
||||
truncate_prompt_tokens = request.truncate_prompt_tokens
|
||||
else:
|
||||
return self.create_error_response(
|
||||
"truncate_prompt_tokens value is "
|
||||
"greater than max_model_len."
|
||||
" Please, select a smaller truncation size.")
|
||||
truncate_prompt_tokens = request.truncate_prompt_tokens
|
||||
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
@@ -104,6 +96,8 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
try:
|
||||
truncate_prompt_tokens = _validate_truncation_size(
|
||||
self.max_model_len, truncate_prompt_tokens)
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
|
||||
@@ -173,7 +173,7 @@ class OpenAIServing:
|
||||
request: AnyRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
prompt: str,
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
|
||||
add_special_tokens: bool,
|
||||
) -> TextTokensPrompt:
|
||||
if (self.model_config.encoder_config is not None
|
||||
@@ -271,7 +271,7 @@ class OpenAIServing:
|
||||
request: AnyRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
prompt_input: Union[str, list[int]],
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
) -> TextTokensPrompt:
|
||||
"""
|
||||
@@ -292,7 +292,7 @@ class OpenAIServing:
|
||||
request: AnyRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
prompt_inputs: Iterable[Union[str, list[int]]],
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
) -> Iterator[TextTokensPrompt]:
|
||||
"""
|
||||
@@ -321,7 +321,7 @@ class OpenAIServing:
|
||||
request: AnyRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
) -> list[TextTokensPrompt]:
|
||||
"""
|
||||
@@ -356,7 +356,7 @@ class OpenAIServing:
|
||||
request: CompletionLikeRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
) -> tuple[list[TextTokensPrompt], list[TokensPrompt]]:
|
||||
request_prompts = await self._tokenize_prompt_input_or_inputs_async(
|
||||
|
||||
@@ -21,6 +21,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
|
||||
PoolingResponseData, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingOutput, PoolingRequestOutput
|
||||
from vllm.utils import merge_async_iterators
|
||||
@@ -85,18 +86,11 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
request_id = f"pool-{self._base_request_id(raw_request)}"
|
||||
created_time = int(time.time())
|
||||
|
||||
truncate_prompt_tokens = None
|
||||
|
||||
if request.truncate_prompt_tokens is not None:
|
||||
if request.truncate_prompt_tokens <= self.max_model_len:
|
||||
truncate_prompt_tokens = request.truncate_prompt_tokens
|
||||
else:
|
||||
return self.create_error_response(
|
||||
"truncate_prompt_tokens value is "
|
||||
"greater than max_model_len."
|
||||
" Please, select a smaller truncation size.")
|
||||
truncate_prompt_tokens = request.truncate_prompt_tokens
|
||||
|
||||
try:
|
||||
truncate_prompt_tokens = _validate_truncation_size(
|
||||
self.max_model_len, truncate_prompt_tokens)
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
|
||||
@@ -18,6 +18,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.score_utils import (_cosine_similarity,
|
||||
_validate_score_input_lens)
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@@ -231,11 +232,6 @@ class ServingScores(OpenAIServing):
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
) -> list[PoolingRequestOutput]:
|
||||
|
||||
tokenization_kwargs: dict[str, Any] = {}
|
||||
if truncate_prompt_tokens is not None:
|
||||
tokenization_kwargs["truncation"] = True
|
||||
tokenization_kwargs["max_length"] = truncate_prompt_tokens
|
||||
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
@@ -247,12 +243,9 @@ class ServingScores(OpenAIServing):
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
if truncate_prompt_tokens is not None and \
|
||||
truncate_prompt_tokens > self.max_model_len:
|
||||
raise ValueError(
|
||||
f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
|
||||
f"is greater than max_model_len ({self.max_model_len})."
|
||||
f" Please, select a smaller truncation size.")
|
||||
tokenization_kwargs: dict[str, Any] = {}
|
||||
_validate_truncation_size(self.max_model_len, truncate_prompt_tokens,
|
||||
tokenization_kwargs)
|
||||
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
|
||||
@@ -46,4 +46,4 @@ def _validate_score_input_lens(
|
||||
if len(texts_1) == 0:
|
||||
raise ValueError("At least one text element must be given")
|
||||
if len(texts_2) == 0:
|
||||
raise ValueError("At least one text_pair element must be given")
|
||||
raise ValueError("At least one text_pair element must be given")
|
||||
@@ -3,6 +3,7 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
@@ -134,3 +135,26 @@ def cli_env_setup():
|
||||
if "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ:
|
||||
logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'")
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
|
||||
def _validate_truncation_size(
|
||||
max_model_len: int,
|
||||
truncate_prompt_tokens: Optional[int],
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
) -> Optional[int]:
|
||||
|
||||
if truncate_prompt_tokens is not None:
|
||||
if truncate_prompt_tokens <= -1:
|
||||
truncate_prompt_tokens = max_model_len
|
||||
|
||||
if truncate_prompt_tokens > max_model_len:
|
||||
raise ValueError(
|
||||
f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
|
||||
f"is greater than max_model_len ({max_model_len})."
|
||||
f" Please, select a smaller truncation size.")
|
||||
|
||||
if tokenization_kwargs is not None:
|
||||
tokenization_kwargs["truncation"] = True
|
||||
tokenization_kwargs["max_length"] = truncate_prompt_tokens
|
||||
|
||||
return truncate_prompt_tokens
|
||||
|
||||
Reference in New Issue
Block a user