Truncation control for embedding models (#14776)

Signed-off-by: Gabriel Marinho <gmarinho@ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
Author: Gabriel Marinho
Date: 2025-04-29 22:24:57 -03:00
Committed by: GitHub
Parent: 4055130a85
Commit: 1c2bc7ead0
21 changed files with 333 additions and 71 deletions

vllm/entrypoints/llm.py

@@ -25,6 +25,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                           resolve_chat_template_content_format)
 from vllm.entrypoints.score_utils import (_cosine_similarity,
                                           _validate_score_input_lens)
+from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
 from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
 from vllm.logger import init_logger
@@ -793,6 +794,7 @@ class LLM:
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
         *,
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -807,6 +809,7 @@ class LLM:
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
         prompt_token_ids: Optional[list[int]] = None,
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -821,6 +824,7 @@ class LLM:
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
         prompt_token_ids: Optional[list[list[int]]] = None,
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -836,6 +840,7 @@ class LLM:
                                        Sequence[PoolingParams]]] = None,
         *,
         prompt_token_ids: list[int],
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -851,6 +856,7 @@ class LLM:
                                        Sequence[PoolingParams]]] = None,
         *,
         prompt_token_ids: list[list[int]],
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -864,6 +870,7 @@ class LLM:
         prompts: None,
         pooling_params: None,
         prompt_token_ids: Union[list[int], list[list[int]]],
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -882,6 +889,7 @@ class LLM:
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
         prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -946,10 +954,15 @@ class LLM:
         for pooling_param in pooling_params:
             pooling_param.verify(self.llm_engine.model_config)

+        tokenization_kwargs: dict[str, Any] = {}
+        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
+                                  truncate_prompt_tokens, tokenization_kwargs)
+
         self._validate_and_add_requests(
             prompts=parsed_prompts,
             params=pooling_params,
             lora_request=lora_request,
+            tokenization_kwargs=tokenization_kwargs,
             prompt_adapter_request=prompt_adapter_request,
         )
@@ -962,6 +975,7 @@ class LLM:
         prompts: Union[PromptType, Sequence[PromptType]],
         /,
         *,
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
@@ -995,6 +1009,7 @@ class LLM:
"Embedding API is only enabled for `--task embed`")
items = self.encode(prompts,
truncate_prompt_tokens=truncate_prompt_tokens,
use_tqdm=use_tqdm,
pooling_params=pooling_params,
lora_request=lora_request,
@@ -1055,6 +1070,7 @@ class LLM:
         encoded_output: list[PoolingRequestOutput] = self.encode(
             text_1 + text_2,
+            truncate_prompt_tokens=truncate_prompt_tokens,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request)
@@ -1098,9 +1114,8 @@ class LLM:
             pooling_params = PoolingParams()

         tokenization_kwargs: dict[str, Any] = {}
-        if truncate_prompt_tokens is not None:
-            tokenization_kwargs["truncation"] = True
-            tokenization_kwargs["max_length"] = truncate_prompt_tokens
+        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
+                                  truncate_prompt_tokens, tokenization_kwargs)

         parsed_prompts = []
@@ -1323,6 +1338,7 @@ class LLM:
                       Sequence[PoolingParams]],
         lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
         prompt_adapter_request: Optional[PromptAdapterRequest],
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         guided_options: Optional[GuidedDecodingRequest] = None,
         priority: Optional[list[int]] = None,
     ) -> None:
@@ -1359,6 +1375,7 @@ class LLM:
             self._add_request(
                 prompt,
                 params[i] if isinstance(params, Sequence) else params,
+                tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request[i] if isinstance(
                     lora_request, Sequence) else lora_request,
                 prompt_adapter_request=prompt_adapter_request,
@@ -1369,6 +1386,7 @@ class LLM:
         self,
         prompt: PromptType,
         params: Union[SamplingParams, PoolingParams],
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -1379,6 +1397,7 @@ class LLM:
             prompt,
             params,
             lora_request=lora_request,
+            tokenization_kwargs=tokenization_kwargs,
             prompt_adapter_request=prompt_adapter_request,
             priority=priority,
         )
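
Taken together, these hunks thread a new truncate_prompt_tokens keyword from LLM.encode()/LLM.embed() down to the tokenizer via tokenization_kwargs. A minimal usage sketch of the new keyword (model name and prompts are illustrative, not part of this commit):

from vllm import LLM

llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed")

# Truncate each prompt to at most 32 tokens before pooling.
outputs = llm.embed(["a very long document ..."], truncate_prompt_tokens=32)

# -1 is a sentinel meaning "truncate to the model's max_model_len";
# values above max_model_len raise a ValueError instead.
outputs = llm.embed(["a very long document ..."], truncate_prompt_tokens=-1)
print(len(outputs[0].outputs.embedding))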

vllm/entrypoints/openai/protocol.py

@@ -1014,7 +1014,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
     encoding_format: Literal["float", "base64"] = "float"
     dimensions: Optional[int] = None
     user: Optional[str] = None
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

     # doc: begin-embedding-pooling-params
     additional_data: Optional[Any] = None
@@ -1049,7 +1049,7 @@ class EmbeddingChatRequest(OpenAIBaseModel):
     encoding_format: Literal["float", "base64"] = "float"
     dimensions: Optional[int] = None
     user: Optional[str] = None
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

     # doc: begin-chat-embedding-pooling-params
     additional_data: Optional[Any] = None
@@ -1116,7 +1116,7 @@ class ScoreRequest(OpenAIBaseModel):
     model: Optional[str] = None
     text_1: Union[list[str], str]
     text_2: Union[list[str], str]
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

     # doc: begin-score-pooling-params
     additional_data: Optional[Any] = None
@@ -1142,7 +1142,7 @@ class RerankRequest(OpenAIBaseModel):
     query: str
     documents: list[str]
     top_n: int = Field(default_factory=lambda: 0)
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

     # doc: begin-rerank-pooling-params
     additional_data: Optional[Any] = None
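
Relaxing the Pydantic bound from ge=1 to ge=-1 lets HTTP clients pass the -1 sentinel on the embedding, pooling, score, and rerank requests. A hedged client-side sketch (server URL and model name are assumptions, not part of this commit):

import requests

resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "intfloat/e5-mistral-7b-instruct",
        "input": "a prompt that may exceed the model's context window ...",
        # -1 means "truncate to max_model_len"; anything larger than
        # max_model_len is still rejected by the server.
        "truncate_prompt_tokens": -1,
    },
)
print(resp.json())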

vllm/entrypoints/openai/serving_embedding.py

@@ -21,6 +21,7 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
                                               ErrorResponse, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.logger import init_logger
 from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
                           PoolingRequestOutput)
@@ -85,16 +86,7 @@ class OpenAIServingEmbedding(OpenAIServing):
request_id = f"embd-{self._base_request_id(raw_request)}"
created_time = int(time.time())
truncate_prompt_tokens = None
if request.truncate_prompt_tokens is not None:
if request.truncate_prompt_tokens <= self.max_model_len:
truncate_prompt_tokens = request.truncate_prompt_tokens
else:
return self.create_error_response(
"truncate_prompt_tokens value is "
"greater than max_model_len."
" Please, select a smaller truncation size.")
truncate_prompt_tokens = request.truncate_prompt_tokens
pooling_params = request.to_pooling_params()
@@ -104,6 +96,8 @@ class OpenAIServingEmbedding(OpenAIServing):
             return self.create_error_response(str(e))

         try:
+            truncate_prompt_tokens = _validate_truncation_size(
+                self.max_model_len, truncate_prompt_tokens)
             (
                 lora_request,
                 prompt_adapter_request,

vllm/entrypoints/openai/serving_engine.py

@@ -173,7 +173,7 @@ class OpenAIServing:
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt: str,
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
         add_special_tokens: bool,
     ) -> TextTokensPrompt:
         if (self.model_config.encoder_config is not None
if (self.model_config.encoder_config is not None
@@ -271,7 +271,7 @@ class OpenAIServing:
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt_input: Union[str, list[int]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
     ) -> TextTokensPrompt:
         """
@@ -292,7 +292,7 @@ class OpenAIServing:
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt_inputs: Iterable[Union[str, list[int]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
     ) -> Iterator[TextTokensPrompt]:
         """
@@ -321,7 +321,7 @@ class OpenAIServing:
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
     ) -> list[TextTokensPrompt]:
         """
@@ -356,7 +356,7 @@ class OpenAIServing:
         request: CompletionLikeRequest,
         tokenizer: AnyTokenizer,
         input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
     ) -> tuple[list[TextTokensPrompt], list[TokensPrompt]]:
         request_prompts = await self._tokenize_prompt_input_or_inputs_async(
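
The same ge=1 -> ge=-1 relaxation is applied to the tokenize-prompt helpers' annotations. A standalone sketch of what the new bound accepts and rejects (illustrative only, using pydantic directly rather than vLLM's request classes):

from typing import Annotated, Optional
from pydantic import BaseModel, Field, ValidationError

class Probe(BaseModel):
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

Probe(truncate_prompt_tokens=128)  # explicit size: accepted, as before
Probe(truncate_prompt_tokens=-1)   # sentinel for max_model_len: now accepted
try:
    Probe(truncate_prompt_tokens=-2)  # anything below -1 is still rejected
except ValidationError as exc:
    print(exc)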

vllm/entrypoints/openai/serving_pooling.py

@@ -21,6 +21,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                               PoolingResponseData, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.logger import init_logger
 from vllm.outputs import PoolingOutput, PoolingRequestOutput
 from vllm.utils import merge_async_iterators
@@ -85,18 +86,11 @@ class OpenAIServingPooling(OpenAIServing):
request_id = f"pool-{self._base_request_id(raw_request)}"
created_time = int(time.time())
truncate_prompt_tokens = None
if request.truncate_prompt_tokens is not None:
if request.truncate_prompt_tokens <= self.max_model_len:
truncate_prompt_tokens = request.truncate_prompt_tokens
else:
return self.create_error_response(
"truncate_prompt_tokens value is "
"greater than max_model_len."
" Please, select a smaller truncation size.")
truncate_prompt_tokens = request.truncate_prompt_tokens
try:
truncate_prompt_tokens = _validate_truncation_size(
self.max_model_len, truncate_prompt_tokens)
(
lora_request,
prompt_adapter_request,

vllm/entrypoints/openai/serving_score.py

@@ -18,6 +18,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.score_utils import (_cosine_similarity,
                                           _validate_score_input_lens)
+from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.inputs.data import TokensPrompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -231,11 +232,6 @@ class ServingScores(OpenAIServing):
         truncate_prompt_tokens: Optional[int] = None,
     ) -> list[PoolingRequestOutput]:

-        tokenization_kwargs: dict[str, Any] = {}
-        if truncate_prompt_tokens is not None:
-            tokenization_kwargs["truncation"] = True
-            tokenization_kwargs["max_length"] = truncate_prompt_tokens
-
         (
             lora_request,
             prompt_adapter_request,
@@ -247,12 +243,9 @@ class ServingScores(OpenAIServing):
         tokenizer = await self.engine_client.get_tokenizer(lora_request)

-        if truncate_prompt_tokens is not None and \
-                truncate_prompt_tokens > self.max_model_len:
-            raise ValueError(
-                f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
-                f"is greater than max_model_len ({self.max_model_len})."
-                f" Please, select a smaller truncation size.")
+        tokenization_kwargs: dict[str, Any] = {}
+        _validate_truncation_size(self.max_model_len, truncate_prompt_tokens,
+                                  tokenization_kwargs)

         trace_headers = (None if raw_request is None else await
                          self._get_trace_headers(raw_request.headers))
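
With the duplicated bounds check replaced by _validate_truncation_size, the score endpoint accepts the same sentinel as embeddings. A hedged request sketch (route, URL, and model name are assumptions, not part of this commit):

import requests

resp = requests.post(
    "http://localhost:8000/score",
    json={
        "model": "BAAI/bge-reranker-base",
        "text_1": "What is the capital of France?",
        "text_2": ["Paris is the capital of France.",
                   "The Eiffel Tower is in Paris."],
        "truncate_prompt_tokens": -1,  # valid now that the field allows ge=-1
    },
)
print(resp.json())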

vllm/entrypoints/score_utils.py

@@ -46,4 +46,4 @@ def _validate_score_input_lens(
     if len(texts_1) == 0:
         raise ValueError("At least one text element must be given")

     if len(texts_2) == 0:
-        raise ValueError("At least one text_pair element must be given")
\ No newline at end of file
+        raise ValueError("At least one text_pair element must be given")

vllm/entrypoints/utils.py

@@ -3,6 +3,7 @@
 import asyncio
 import functools
 import os
+from typing import Any, Optional

 from fastapi import Request
 from fastapi.responses import JSONResponse, StreamingResponse
@@ -134,3 +135,26 @@ def cli_env_setup():
if "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ:
logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def _validate_truncation_size(
max_model_len: int,
truncate_prompt_tokens: Optional[int],
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> Optional[int]:
if truncate_prompt_tokens is not None:
if truncate_prompt_tokens <= -1:
truncate_prompt_tokens = max_model_len
if truncate_prompt_tokens > max_model_len:
raise ValueError(
f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
f"is greater than max_model_len ({max_model_len})."
f" Please, select a smaller truncation size.")
if tokenization_kwargs is not None:
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens
return truncate_prompt_tokens
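
The helper above becomes the single source of truth for truncation validation across all entrypoints. A behavior sketch against the definition as shown (the literal max_model_len value is illustrative):

from typing import Any

max_model_len = 512

# -1 is normalized to max_model_len, and the tokenizer kwargs are populated.
kwargs: dict[str, Any] = {}
assert _validate_truncation_size(max_model_len, -1, kwargs) == 512
assert kwargs == {"truncation": True, "max_length": 512}

# In-range values pass through unchanged.
kwargs = {}
assert _validate_truncation_size(max_model_len, 128, kwargs) == 128
assert kwargs == {"truncation": True, "max_length": 128}

# Oversized values raise instead of silently truncating.
try:
    _validate_truncation_size(max_model_len, 1024)
except ValueError as exc:
    print(exc)

# None disables truncation entirely and leaves the kwargs untouched.
assert _validate_truncation_size(max_model_len, None, {}) is None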