Extend renderer with embedding support and integrate completion endpoint (#24405)
Signed-off-by: sfeng33 <4florafeng@gmail.com>
@@ -2,12 +2,16 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import asyncio
+import io
 from abc import ABC, abstractmethod
 from typing import Annotated, Optional, Union
 
+import pybase64
+import torch
 from pydantic import Field
 
 from vllm.config import ModelConfig
+from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -49,37 +53,121 @@ class BaseRenderer(ABC):
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: Optional[bool] = True,
         cache_salt: Optional[str] = None,
+        needs_detokenization: Optional[bool] = False,
     ) -> list[EngineTokensPrompt]:
         """
-        Convert input prompts into tokenized format for engine processing.
-
-        This is the core method that transforms various input formats into
-        standardized TokensPrompt objects. Implementations should handle
-        tokenization, special token insertion, truncation, and validation
-        according to model requirements.
+        Convert text or token inputs into engine-ready TokensPrompt objects.
+
+        This method accepts text or token inputs and produces a
+        list of [`TokensPrompt`][vllm.inputs.data.TokensPrompt] objects
+        for the engine.
 
         Args:
-            prompt_or_prompts: Input data in various formats:
-                - str: Single text prompt
-                - list[str]: Batch of text prompts
-                - list[int]: Pre-tokenized sequence
-                - list[list[int]]: Batch of pre-tokenized sequences
-            max_length: Maximum sequence length (endpoint-specific behavior)
-            truncate_prompt_tokens: Truncate to last N tokens
-                (None=no truncation, 0=empty)
-            add_special_tokens: Add model-specific tokens (e.g., [CLS], [SEP])
-                to text inputs
-            cache_salt: Optional string to disambiguate cached prompts
+            prompt_or_prompts: One of:
+                - ``str``: Single text prompt.
+                - ``list[str]``: Batch of text prompts.
+                - ``list[int]``: Single pre-tokenized sequence.
+                - ``list[list[int]]``: Batch of pre-tokenized sequences.
+            max_length: Maximum allowable total input token length. If
+                provided, token inputs longer than this raise ``ValueError``.
+            truncate_prompt_tokens: Number of tokens to keep. ``None`` means no
+                truncation. ``0`` yields an empty list (and skips embeds).
+                ``-1`` maps to ``model_config.max_model_len``.
+            add_special_tokens: Whether to add model-specific special tokens
+                during text tokenization.
+            cache_salt: Optional string to disambiguate prefix cache entries.
+            needs_detokenization: If True and ``prompt_or_prompts`` is token
+                input, detokenize IDs back to text for inclusion in outputs.
 
         Returns:
-            list[EngineTokensPrompt]: Tokenized prompts ready for engine
-                consumption
+            list[EngineTokensPrompt]: Engine-ready token prompts.
 
         Raises:
-            ValueError: If input format is invalid or length limits exceeded
+            ValueError: If input formats are invalid or length limits
+                exceeded.
         """
         raise NotImplementedError
+
+    @abstractmethod
+    async def render_prompt_and_embeds(
+        self,
+        prompt_or_prompts: Optional[Union[str, list[str], list[int],
+                                          list[list[int]]]] = None,
+        prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
+        max_length: Optional[int] = None,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
+        add_special_tokens: Optional[bool] = True,
+        cache_salt: Optional[str] = None,
+        needs_detokenization: Optional[bool] = False,
+    ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
+        """
+        Convert text/token and/or base64-encoded embeddings inputs into
+        engine-ready prompt objects.
+
+        At least one of ``prompt_or_prompts`` or ``prompt_embeds`` must be
+        provided and non-empty. If both are omitted or empty (e.g., an empty
+        string and an empty list), a ``ValueError`` is raised.
+
+        Args:
+            prompt_or_prompts: Text or token inputs to include.
+            prompt_embeds: Base64-encoded bytes (or list thereof) containing
+                a torch-saved tensor to be used as prompt embeddings.
+            max_length: Maximum allowable total input token length. If
+                provided, inputs longer than this raise ``ValueError``.
+            truncate_prompt_tokens: Number of tokens/rows to keep from the
+                end of the sequence. ``-1`` maps to
+                ``model_config.max_model_len``.
+            add_special_tokens: Whether to add model-specific special tokens
+                during text tokenization.
+            cache_salt: Optional string to disambiguate prefix cache entries.
+            needs_detokenization: If True and ``prompt_or_prompts`` is token
+                input, detokenize IDs back to text for inclusion in outputs.
+
+        Returns:
+            list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
+                Engine-ready prompt objects.
+
+        Raises:
+            ValueError: If both ``prompt_or_prompts`` and ``prompt_embeds``
+                are omitted or empty (decoder prompt cannot be empty), or if
+                length limits are exceeded.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def load_prompt_embeds(
+        cls,
+        prompt_embeds: Union[bytes, list[bytes]],
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=0)]] = None,
+        cache_salt: Optional[str] = None,
+    ) -> list[EngineEmbedsPrompt]:
+        """Load and validate base64-encoded embeddings into prompt objects."""
+
+        def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt:
+            tensor = torch.load(
+                io.BytesIO(pybase64.b64decode(embed, validate=True)),
+                weights_only=True,
+                map_location=torch.device("cpu"),
+            )
+            assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
+                torch.float32,
+                torch.bfloat16,
+                torch.float16,
+            )
+            tensor = tensor.to_dense()
+            if tensor.dim() > 2:
+                tensor = tensor.squeeze(0)
+            assert tensor.dim() == 2
+            if truncate_prompt_tokens is not None:
+                tensor = tensor[-truncate_prompt_tokens:]
+            embeds_prompt = EngineEmbedsPrompt(prompt_embeds=tensor)
+            if cache_salt is not None:
+                embeds_prompt["cache_salt"] = cache_salt
+            return embeds_prompt
+
+        if isinstance(prompt_embeds, list):
+            return [_load_and_validate_embed(embed) for embed in prompt_embeds]
+        else:
+            return [_load_and_validate_embed(prompt_embeds)]
 
 
 class CompletionRenderer(BaseRenderer):
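
The contract above expects ``prompt_embeds`` to be the base64-encoded bytes
of a ``torch.save``-serialized 2-D tensor. A minimal client-side encoding
sketch (the tensor shape, dtype, and variable names are illustrative
assumptions, not part of this diff):

    import io

    import pybase64
    import torch

    # [num_tokens, hidden_size], using one of the float dtypes that
    # load_prompt_embeds accepts after decoding.
    embeds = torch.randn(16, 4096, dtype=torch.float16)

    buf = io.BytesIO()
    torch.save(embeds, buf)
    prompt_embeds = pybase64.b64encode(buf.getvalue())

``load_prompt_embeds`` reverses this: it base64-decodes with
``validate=True``, loads the buffer on CPU with ``weights_only=True``,
checks for a 2-D float tensor, and wraps the result in an
``EngineEmbedsPrompt``.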
@@ -101,50 +189,110 @@ class CompletionRenderer(BaseRenderer):
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: Optional[bool] = True,
         cache_salt: Optional[str] = None,
+        needs_detokenization: Optional[bool] = False,
     ) -> list[EngineTokensPrompt]:
         """Implementation of prompt rendering for completion-style requests.
 
         Uses async tokenizer pooling for improved performance. See base class
         for detailed parameter documentation.
         """
-        if truncate_prompt_tokens is not None:
-            if truncate_prompt_tokens == 0:
-                return []
-            if truncate_prompt_tokens < 0:
-                truncate_prompt_tokens = self.model_config.max_model_len
-            if max_length is not None and truncate_prompt_tokens > max_length:
-                raise ValueError(
-                    f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
-                    f"cannot be greater than max_length ({max_length}). "
-                    f"Please select a smaller truncation size.")
+        truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
+            truncate_prompt_tokens, max_length)
+        if truncate_prompt_tokens == 0:
+            return []
 
         # Parse and batch the input prompts
         batch_inputs = parse_and_batch_prompt(prompt_or_prompts)
 
-        rendered_prompts: list[EngineTokensPrompt] = []
-        tokenize_tasks = []
+        tasks = []
         for prompt_input in batch_inputs:
             if prompt_input["is_tokens"] is True:
                 # Token input
-                token_ids = self._maybe_apply_truncation(
-                    prompt_input["content"], truncate_prompt_tokens)
-                rendered_prompts.append(
-                    self._create_tokens_prompt(token_ids, max_length,
-                                               cache_salt))
+                detokenize_task = asyncio.create_task(
+                    # Note: detokenization is needed when echo is enabled,
+                    # where the input token IDs are decoded back to text.
+                    self._maybe_detokenize(prompt_input["content"], max_length,
+                                           truncate_prompt_tokens, cache_salt,
+                                           needs_detokenization))
+                tasks.append(detokenize_task)
             else:
                 # Text input
                 tokenize_task = asyncio.create_task(
                     self._tokenize(prompt_input["content"], max_length,
                                    truncate_prompt_tokens, add_special_tokens,
                                    cache_salt))
-                tokenize_tasks.append(tokenize_task)
+                tasks.append(tokenize_task)
 
         # Wait for all text tokenization to finish
-        if tokenize_tasks:
-            tokenized_text_prompts = await asyncio.gather(*tokenize_tasks)
-            rendered_prompts.extend(tokenized_text_prompts)
+        if tasks:
+            tokenized_text_prompts = await asyncio.gather(*tasks)
+            return tokenized_text_prompts
 
-        return rendered_prompts
+        return []
 
+    async def render_prompt_and_embeds(
+        self,
+        prompt_or_prompts: Optional[Union[str, list[str], list[int],
+                                          list[list[int]]]] = None,
+        prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
+        max_length: Optional[int] = None,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
+        add_special_tokens: Optional[bool] = True,
+        cache_salt: Optional[str] = None,
+        needs_detokenization: Optional[bool] = False,
+    ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
+        """
+        Render text/token prompts and/or precomputed embedding prompts. At
+        least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
+        """
+        truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
+            truncate_prompt_tokens, max_length)
+        if truncate_prompt_tokens == 0:
+            return []
+
+        rendered: list[Union[EngineTokensPrompt, EngineEmbedsPrompt]] = []
+
+        if prompt_embeds is not None:
+            rendered.extend(
+                self.load_prompt_embeds(prompt_embeds, truncate_prompt_tokens,
+                                        cache_salt))
+        if prompt_or_prompts is None or prompt_or_prompts == "":
+            return rendered
+
+        token_prompts = await self.render_prompt(
+            prompt_or_prompts=prompt_or_prompts,
+            max_length=max_length,
+            truncate_prompt_tokens=truncate_prompt_tokens,
+            add_special_tokens=add_special_tokens,
+            cache_salt=cache_salt,
+            needs_detokenization=needs_detokenization,
+        )
+        rendered.extend(token_prompts)
+
+        return rendered
+
+    def _validate_and_normalize_truncate_tokens(
+        self,
+        truncate_prompt_tokens: Optional[int],
+        max_length: Optional[int],
+    ) -> Optional[int]:
+        """Validate and normalize truncate_prompt_tokens parameter."""
+        if truncate_prompt_tokens is None:
+            return None
+
+        if truncate_prompt_tokens == 0:
+            return 0
+
+        if truncate_prompt_tokens < 0:
+            truncate_prompt_tokens = self.model_config.max_model_len
+
+        if max_length is not None and truncate_prompt_tokens > max_length:
+            raise ValueError(
+                f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
+                f"cannot be greater than max_length ({max_length}). "
+                f"Please select a smaller truncation size.")
+
+        return truncate_prompt_tokens
+
     def _maybe_apply_truncation(
             self, token_ids: list[int],
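
Since embedding prompts are loaded before any text or token prompts are
rendered, embeds always precede token prompts in the returned list. A
hypothetical end-to-end call (the ``renderer`` object and all argument
values are assumptions for illustration):

    async def render_mixed(renderer, prompt_embeds: bytes):
        # truncate_prompt_tokens follows the normalized semantics above:
        # None = no truncation, 0 = empty result, -1 = max_model_len.
        return await renderer.render_prompt_and_embeds(
            prompt_or_prompts="Continue this story:",
            prompt_embeds=prompt_embeds,  # bytes as sketched earlier
            truncate_prompt_tokens=-1,
            cache_salt="req-123",
        )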
@@ -186,7 +334,29 @@ class CompletionRenderer(BaseRenderer):
                                        max_length=truncate_prompt_tokens)
 
         return self._create_tokens_prompt(encoded.input_ids, max_length,
-                                          cache_salt)
+                                          cache_salt, text)
+
+    async def _maybe_detokenize(
+        self,
+        token_ids: list[int],
+        max_length: Optional[int],
+        truncate_prompt_tokens: Optional[int],
+        cache_salt: Optional[str],
+        needs_detokenization: Optional[bool] = False,
+    ) -> EngineTokensPrompt:
+        """Optionally detokenize token IDs and build a tokens prompt."""
+        token_ids = self._maybe_apply_truncation(token_ids,
+                                                 truncate_prompt_tokens)
+
+        prompt = None
+        if needs_detokenization is True:
+            async_tokenizer = self._get_async_tokenizer()
+            prompt = await async_tokenizer.decode(token_ids)
+
+        return self._create_tokens_prompt(token_ids=token_ids,
+                                          max_length=max_length,
+                                          cache_salt=cache_salt,
+                                          prompt=prompt)
 
     def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
         """Get or create async tokenizer using shared pool."""
@@ -210,6 +380,7 @@ class CompletionRenderer(BaseRenderer):
         token_ids: list[int],
         max_length: Optional[int] = None,
         cache_salt: Optional[str] = None,
+        prompt: Optional[str] = None,
     ) -> EngineTokensPrompt:
         """Create validated EngineTokensPrompt."""
         if max_length is not None and len(token_ids) > max_length:
@@ -221,4 +392,6 @@ class CompletionRenderer(BaseRenderer):
         tokens_prompt = EngineTokensPrompt(prompt_token_ids=token_ids)
         if cache_salt is not None:
             tokens_prompt["cache_salt"] = cache_salt
-        return tokens_prompt
+        if prompt is not None:
+            tokens_prompt["prompt"] = prompt
+        return tokens_prompt
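
On the integration side, these changes let the completions endpoint route
both text prompts and uploaded embeddings through one renderer. A hedged
HTTP sketch (the ``prompt_embeds`` request field and server URL are
assumptions based on the commit title, not shown in these hunks):

    import io

    import pybase64
    import requests
    import torch

    buf = io.BytesIO()
    torch.save(torch.randn(8, 4096, dtype=torch.float16), buf)

    resp = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "model": "my-model",
            "prompt_embeds": pybase64.b64encode(buf.getvalue()).decode(),
            "max_tokens": 16,
        },
    )
    print(resp.json())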