Extend renderer with embedding support and integrate completion endpoint (#24405)

Signed-off-by: sfeng33 <4florafeng@gmail.com>
Author: Flora Feng
Date: 2025-09-09 10:46:46 -07:00
Committed by: GitHub
Parent: 9ad0688e43
Commit: 15cb047e25
9 changed files with 410 additions and 309 deletions


@@ -2,12 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import io
from abc import ABC, abstractmethod
from typing import Annotated, Optional, Union
import pybase64
import torch
from pydantic import Field
from vllm.config import ModelConfig
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -49,37 +53,121 @@ class BaseRenderer(ABC):
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: Optional[bool] = True,
cache_salt: Optional[str] = None,
needs_detokenization: Optional[bool] = False,
) -> list[EngineTokensPrompt]:
"""
- Convert input prompts into tokenized format for engine processing.
- This is the core method that transforms various input formats into
- standardized TokensPrompt objects. Implementations should handle
- tokenization, special token insertion, truncation, and validation
- according to model requirements.
+ Convert text or token inputs into engine-ready TokensPrompt objects.
+ This method accepts text or token inputs and produces a
+ list of [`TokensPrompt`][vllm.inputs.data.TokensPrompt] objects
+ for the engine.
Args:
- prompt_or_prompts: Input data in various formats:
- - str: Single text prompt
- - list[str]: Batch of text prompts
- - list[int]: Pre-tokenized sequence
- - list[list[int]]: Batch of pre-tokenized sequences
- max_length: Maximum sequence length (endpoint-specific behavior)
- truncate_prompt_tokens: Truncate to last N tokens
- (None=no truncation, 0=empty)
- add_special_tokens: Add model-specific tokens (e.g., [CLS], [SEP])
- to text inputs
- cache_salt: Optional string to disambiguate cached prompts
+ prompt_or_prompts: One of:
+ - ``str``: Single text prompt.
+ - ``list[str]``: Batch of text prompts.
+ - ``list[int]``: Single pre-tokenized sequence.
+ - ``list[list[int]]``: Batch of pre-tokenized sequences.
+ max_length: Maximum allowable total input token length. If provided,
+ token inputs longer than this raise ``ValueError``.
+ truncate_prompt_tokens: Number of tokens to keep. ``None`` means no
+ truncation. ``0`` yields an empty list (and skips embeds).
+ ``-1`` maps to ``model_config.max_model_len``.
+ add_special_tokens: Whether to add model-specific special tokens
+ during text tokenization.
+ cache_salt: Optional string to disambiguate prefix cache entries.
+ needs_detokenization: If True and ``prompt_or_prompts`` is token
+ input, detokenize IDs back to text for inclusion in outputs.
Returns:
- list[EngineTokensPrompt]: Tokenized prompts ready for engine
- consumption
+ list[EngineTokensPrompt]: Engine-ready token prompts.
Raises:
- ValueError: If input format is invalid or length limits exceeded
+ ValueError: If input formats are invalid or length limits exceeded.
"""
raise NotImplementedError
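For reference, a minimal usage sketch of these parameters against a concrete renderer; the `renderer` instance and the token IDs below are illustrative and not part of this diff:

async def render_prompt_demo(renderer):
    # Pre-tokenized input: keep only the last two tokens.
    token_prompts = await renderer.render_prompt(
        prompt_or_prompts=[101, 7592, 2088, 102],
        truncate_prompt_tokens=2,
    )
    # -> [{"prompt_token_ids": [2088, 102]}]

    # Batch of text prompts; the tokenizer adds special tokens during encoding.
    text_prompts = await renderer.render_prompt(
        prompt_or_prompts=["Hello world", "How are you?"],
        add_special_tokens=True,
    )
    return token_prompts + text_prompts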
@abstractmethod
async def render_prompt_and_embeds(
self,
prompt_or_prompts: Optional[Union[str, list[str], list[int],
list[list[int]]]] = None,
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
max_length: Optional[int] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: Optional[bool] = True,
cache_salt: Optional[str] = None,
needs_detokenization: Optional[bool] = False,
) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
"""
Convert text/token and/or base64-encoded embeddings inputs into
engine-ready prompt objects.
At least one of ``prompt_or_prompts`` or ``prompt_embeds`` must be
provided and non-empty. If both are omitted or empty (e.g., empty
string and empty list), a ``ValueError`` is raised.
Args:
prompt_or_prompts: Text or token inputs to include.
prompt_embeds: Base64-encoded bytes (or list thereof) containing a
torch-saved tensor to be used as prompt embeddings.
max_length: Maximum allowable total input token length. If provided,
inputs longer than this raise ``ValueError``.
truncate_prompt_tokens: Number of tokens/rows to keep from the end
of the sequence. ``-1`` maps to ``model_config.max_model_len``.
add_special_tokens: Whether to add model-specific special tokens
during text tokenization.
cache_salt: Optional string to disambiguate prefix cache entries.
needs_detokenization: If True and ``prompt_or_prompts`` is token
input, detokenize IDs back to text for inclusion in outputs.
Returns:
list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
Engine-ready prompt objects.
Raises:
ValueError: If both ``prompt_or_prompts`` and ``prompt_embeds``
are omitted or empty (decoder prompt cannot be empty), or if
length limits are exceeded.
"""
raise NotImplementedError
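For orientation, both return types are TypedDicts from vllm.inputs.data; a sketch of their shapes with illustrative values (not taken from this diff):

import torch

# EngineTokensPrompt: built from text or token input.
tokens_prompt = {"prompt_token_ids": [101, 7592, 102], "cache_salt": "req-1"}

# EngineEmbedsPrompt: built from a decoded base64 tensor payload; prompt_embeds
# is a 2D CPU tensor of shape [num_tokens, hidden_size].
embeds_prompt = {"prompt_embeds": torch.zeros(3, 4096, dtype=torch.float16)}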
@classmethod
def load_prompt_embeds(
cls,
prompt_embeds: Union[bytes, list[bytes]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=0)]] = None,
cache_salt: Optional[str] = None,
) -> list[EngineEmbedsPrompt]:
"""Load and validate base64-encoded embeddings into prompt objects."""
def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt:
tensor = torch.load(
io.BytesIO(pybase64.b64decode(embed, validate=True)),
weights_only=True,
map_location=torch.device("cpu"),
)
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
torch.float32,
torch.bfloat16,
torch.float16,
)
tensor = tensor.to_dense()
if tensor.dim() > 2:
tensor = tensor.squeeze(0)
assert tensor.dim() == 2
if truncate_prompt_tokens is not None:
tensor = tensor[-truncate_prompt_tokens:]
embeds_prompt = EngineEmbedsPrompt(prompt_embeds=tensor)
if cache_salt is not None:
embeds_prompt["cache_salt"] = cache_salt
return embeds_prompt
if isinstance(prompt_embeds, list):
return [_load_and_validate_embed(embed) for embed in prompt_embeds]
else:
return [_load_and_validate_embed(prompt_embeds)]
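A client-side sketch, not part of this commit, of producing a payload that load_prompt_embeds can decode; it assumes a 2D [num_tokens, hidden_size] float tensor serialized with torch.save:

import io

import pybase64
import torch

embeds = torch.randn(8, 4096, dtype=torch.float16)  # e.g. 8 tokens, hidden size 4096
buf = io.BytesIO()
torch.save(embeds, buf)
encoded = pybase64.b64encode(buf.getvalue())  # bytes accepted as `prompt_embeds`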
class CompletionRenderer(BaseRenderer):
@@ -101,50 +189,110 @@ class CompletionRenderer(BaseRenderer):
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: Optional[bool] = True,
cache_salt: Optional[str] = None,
needs_detokenization: Optional[bool] = False,
) -> list[EngineTokensPrompt]:
"""Implementation of prompt rendering for completion-style requests.
Uses async tokenizer pooling for improved performance. See base class
for detailed parameter documentation.
"""
- if truncate_prompt_tokens is not None:
- if truncate_prompt_tokens == 0:
- return []
- if truncate_prompt_tokens < 0:
- truncate_prompt_tokens = self.model_config.max_model_len
- if max_length is not None and truncate_prompt_tokens > max_length:
- raise ValueError(
- f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
- f"cannot be greater than max_length ({max_length}). "
- f"Please select a smaller truncation size.")
+ truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
+ truncate_prompt_tokens, max_length)
+ if truncate_prompt_tokens == 0:
+ return []
# Parse and batch the input prompts
batch_inputs = parse_and_batch_prompt(prompt_or_prompts)
- rendered_prompts: list[EngineTokensPrompt] = []
- tokenize_tasks = []
+ tasks = []
for prompt_input in batch_inputs:
if prompt_input["is_tokens"] is True:
# Token input
- token_ids = self._maybe_apply_truncation(
- prompt_input["content"], truncate_prompt_tokens)
- rendered_prompts.append(
- self._create_tokens_prompt(token_ids, max_length,
- cache_salt))
+ detokenize_task = asyncio.create_task(
+ # Note: detokenization is needed when echo is enabled,
+ # where the input token IDs are decoded back to text.
+ self._maybe_detokenize(prompt_input["content"], max_length,
+ truncate_prompt_tokens, cache_salt,
+ needs_detokenization))
+ tasks.append(detokenize_task)
else:
# Text input
tokenize_task = asyncio.create_task(
self._tokenize(prompt_input["content"], max_length,
truncate_prompt_tokens, add_special_tokens,
cache_salt))
- tokenize_tasks.append(tokenize_task)
+ tasks.append(tokenize_task)
# Wait for all text tokenization to finish
- if tokenize_tasks:
- tokenized_text_prompts = await asyncio.gather(*tokenize_tasks)
- rendered_prompts.extend(tokenized_text_prompts)
+ if tasks:
+ tokenized_text_prompts = await asyncio.gather(*tasks)
+ return tokenized_text_prompts
- return rendered_prompts
+ return []
async def render_prompt_and_embeds(
self,
prompt_or_prompts: Optional[Union[str, list[str], list[int],
list[list[int]]]] = None,
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
max_length: Optional[int] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: Optional[bool] = True,
cache_salt: Optional[str] = None,
needs_detokenization: Optional[bool] = False,
) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
"""
Render text/token prompts and/or precomputed embedding prompts. At
least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
"""
truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
truncate_prompt_tokens, max_length)
if truncate_prompt_tokens == 0:
return []
rendered: list[Union[EngineTokensPrompt, EngineEmbedsPrompt]] = []
if prompt_embeds is not None:
rendered.extend(
self.load_prompt_embeds(prompt_embeds, truncate_prompt_tokens,
cache_salt))
if prompt_or_prompts is None or prompt_or_prompts == "":
return rendered
token_prompts = await self.render_prompt(
prompt_or_prompts=prompt_or_prompts,
max_length=max_length,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
cache_salt=cache_salt,
needs_detokenization=needs_detokenization,
)
rendered.extend(token_prompts)
return rendered
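Putting the two input paths together, a hedged sketch of a mixed request; the renderer construction is elided and `encoded` is a base64 payload as in the earlier client-side sketch:

async def mixed_request(renderer, encoded: bytes):
    prompts = await renderer.render_prompt_and_embeds(
        prompt_or_prompts="Summarize the following:",
        prompt_embeds=encoded,
        truncate_prompt_tokens=-1,  # normalized to model_config.max_model_len
    )
    # Embeds prompts are loaded first, then tokenized text prompts are appended.
    return prompts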
def _validate_and_normalize_truncate_tokens(
self,
truncate_prompt_tokens: Optional[int],
max_length: Optional[int],
) -> Optional[int]:
"""Validate and normalize truncate_prompt_tokens parameter."""
if truncate_prompt_tokens is None:
return None
if truncate_prompt_tokens == 0:
return 0
if truncate_prompt_tokens < 0:
truncate_prompt_tokens = self.model_config.max_model_len
if max_length is not None and truncate_prompt_tokens > max_length:
raise ValueError(
f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
f"cannot be greater than max_length ({max_length}). "
f"Please select a smaller truncation size.")
return truncate_prompt_tokens
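The normalization rules read naturally as a few illustrative assertions, assuming a hypothetical model_config.max_model_len of 4096:

assert renderer._validate_and_normalize_truncate_tokens(None, None) is None
assert renderer._validate_and_normalize_truncate_tokens(0, 2048) == 0
assert renderer._validate_and_normalize_truncate_tokens(-1, None) == 4096
assert renderer._validate_and_normalize_truncate_tokens(256, 2048) == 256
# -1 normalizes to 4096, which exceeds max_length=2048 -> ValueError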
def _maybe_apply_truncation(
self, token_ids: list[int],
@@ -186,7 +334,29 @@ class CompletionRenderer(BaseRenderer):
max_length=truncate_prompt_tokens)
return self._create_tokens_prompt(encoded.input_ids, max_length,
- cache_salt)
+ cache_salt, text)
async def _maybe_detokenize(
self,
token_ids: list[int],
max_length: Optional[int],
truncate_prompt_tokens: Optional[int],
cache_salt: Optional[str],
needs_detokenization: Optional[bool] = False,
) -> EngineTokensPrompt:
"""Optionally detokenize token IDs and build a tokens prompt."""
token_ids = self._maybe_apply_truncation(token_ids,
truncate_prompt_tokens)
prompt = None
if needs_detokenization is True:
async_tokenizer = self._get_async_tokenizer()
prompt = await async_tokenizer.decode(token_ids)
return self._create_tokens_prompt(token_ids=token_ids,
max_length=max_length,
cache_salt=cache_salt,
prompt=prompt)
def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
"""Get or create async tokenizer using shared pool."""
@@ -210,6 +380,7 @@ class CompletionRenderer(BaseRenderer):
token_ids: list[int],
max_length: Optional[int] = None,
cache_salt: Optional[str] = None,
prompt: Optional[str] = None,
) -> EngineTokensPrompt:
"""Create validated EngineTokensPrompt."""
if max_length is not None and len(token_ids) > max_length:
@@ -221,4 +392,6 @@ class CompletionRenderer(BaseRenderer):
tokens_prompt = EngineTokensPrompt(prompt_token_ids=token_ids)
if cache_salt is not None:
tokens_prompt["cache_salt"] = cache_salt
- return tokens_prompt
+ if prompt is not None:
+ tokens_prompt["prompt"] = prompt
+ return tokens_prompt
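Because the commit also wires the completion endpoint through render_prompt_and_embeds, an end-to-end sketch may help; it assumes the server exposes a base64 `prompt_embeds` field on /v1/completions, and the model name and URL are placeholders:

import io

import pybase64
import requests
import torch

buf = io.BytesIO()
torch.save(torch.randn(8, 4096, dtype=torch.float16), buf)
payload = {
    "model": "my-model",  # hypothetical served model name
    "prompt_embeds": pybase64.b64encode(buf.getvalue()).decode("utf-8"),
    "max_tokens": 16,
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload)
print(resp.json())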