Signed-off-by: fake0fan <645327136@qq.com> Signed-off-by: Chenguang ZHENG <645327136@qq.com>
769 lines
26 KiB
Python
769 lines
26 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import asyncio
|
|
import time
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import Mapping, Sequence
|
|
from functools import cached_property
|
|
from typing import TYPE_CHECKING, Any, Generic, overload
|
|
|
|
from typing_extensions import TypeVar
|
|
|
|
from vllm.inputs import (
|
|
EmbedsInputs,
|
|
EmbedsPrompt,
|
|
EncoderDecoderInputs,
|
|
ProcessorInputs,
|
|
SingletonInputs,
|
|
TextPrompt,
|
|
TokenInputs,
|
|
TokensPrompt,
|
|
)
|
|
from vllm.inputs.data import build_enc_dec_inputs, embeds_inputs, token_inputs
|
|
from vllm.logger import init_logger
|
|
from vllm.tokenizers import TokenizerLike
|
|
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
|
|
from vllm.utils.counter import AtomicCounter
|
|
from vllm.utils.torch_utils import set_default_torch_num_threads
|
|
from vllm.v1.metrics.stats import MultiModalCacheStats
|
|
|
|
from .embed_utils import safe_load_prompt_embeds
|
|
from .inputs import (
|
|
DictPrompt,
|
|
EncoderDecoderDictPrompt,
|
|
EncoderDecoderTokPrompt,
|
|
SingletonDictPrompt,
|
|
SingletonTokPrompt,
|
|
TokPrompt,
|
|
)
|
|
from .inputs.preprocess import extract_target_prompt
|
|
from .params import ChatParams, TokenizeParams
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.config import VllmConfig
|
|
from vllm.entrypoints.chat_utils import (
|
|
ChatCompletionMessageParam,
|
|
ConversationMessage,
|
|
)
|
|
from vllm.multimodal.cache import BaseMultiModalProcessorCache
|
|
from vllm.multimodal.inputs import (
|
|
MultiModalDataDict,
|
|
MultiModalInputs,
|
|
MultiModalUUIDDict,
|
|
)
|
|
from vllm.multimodal.parse import MultiModalDataItems, MultiModalUUIDItems
|
|
from vllm.multimodal.processing import BaseMultiModalProcessor
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
_T = TypeVar("_T", bound=TokenizerLike, default=TokenizerLike)
|
|
|
|
|
|
class BaseRenderer(ABC, Generic[_T]):
|
|
@classmethod
@abstractmethod
def from_config(
    cls,
    config: "VllmConfig",
    tokenizer_kwargs: dict[str, Any],
) -> "BaseRenderer":
    """
    Construct a renderer from the engine configuration.

    This is an abstract factory: the base implementation only raises, so
    every concrete renderer has to provide its own version.

    Args:
        config: The vLLM configuration.
        tokenizer_kwargs: Keyword arguments for the tokenizer (presumably
            forwarded to tokenizer construction by subclasses; unused here).

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
|
|
|
|
def __init__(self, config: "VllmConfig", tokenizer: _T | None) -> None:
    """
    Store the configuration and tokenizer, and build the multi-modal
    processor when the model is multi-modal.

    Args:
        config: The vLLM configuration; `model_config`, `parallel_config`
            and (for multi-modal models) `observability_config` are read.
        tokenizer: The tokenizer used for rendering, or `None`.
    """
    super().__init__()

    model_config = config.model_config

    self.config = config
    self.model_config = model_config
    self.api_process_rank = config.parallel_config._api_process_rank

    self.tokenizer = tokenizer

    # Created on first use, since offline LLM doesn't use async
    self._async_tokenizer: AsyncMicrobatchTokenizer | None = None

    self.mm_processor: BaseMultiModalProcessor | None = None
    self._mm_cache_stats: MultiModalCacheStats | None = None

    # NOTE(review): the nesting of this branch is reconstructed from a
    # whitespace-stripped source; the timing registry can only live here
    # because its import is local to the branch -- confirm upstream.
    if model_config.is_multimodal_model:
        from vllm.multimodal import MULTIMODAL_REGISTRY as mm_registry
        from vllm.multimodal.registry import MultiModalTimingRegistry

        processor_cache = mm_registry.processor_cache_from_config(config)

        with set_default_torch_num_threads():
            self.mm_processor = mm_registry.create_processor(
                model_config,
                tokenizer=tokenizer,
                cache=processor_cache,
            )

        if processor_cache:
            self._mm_cache_stats = MultiModalCacheStats()

        # Generates internal request IDs for MM processing only;
        # these have no relation to the request ID for engine core
        self._mm_req_counter = AtomicCounter()
        self._mm_timing_registry = MultiModalTimingRegistry(
            config.observability_config
        )
|
|
|
|
def get_tokenizer(self) -> _T:
    """
    Return the tokenizer held by this renderer.

    Raises:
        ValueError: If no tokenizer was provided (`self.tokenizer` is None).
    """
    if (tok := self.tokenizer) is None:
        raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")

    return tok
|
|
|
|
def get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
    """
    Return the async tokenizer wrapper, creating it on first call.

    The wrapper is built around `self.get_tokenizer()` and then reused
    on every later call.
    """
    cached = self._async_tokenizer
    if cached is not None:
        return cached

    cached = self._async_tokenizer = AsyncMicrobatchTokenizer(self.get_tokenizer())
    return cached
|
|
|
|
def get_mm_processor(self) -> "BaseMultiModalProcessor":
    """
    Return the multi-modal processor.

    Raises:
        ValueError: If no processor was created (`self.mm_processor` is None).
    """
    if (processor := self.mm_processor) is None:
        raise ValueError("Multi-modal processor not available for text-only models")

    return processor
|
|
|
|
@property
def mm_processor_cache(self) -> "BaseMultiModalProcessorCache | None":
    """
    The cache attached to the multi-modal processor, or `None` when
    there is no multi-modal processor.
    """
    processor = self.mm_processor
    return None if processor is None else processor.cache
|
|
|
|
def stat_mm_cache(self) -> MultiModalCacheStats | None:
    """
    Hand back the accumulated multi-modal cache stats and start a new
    accumulator.

    Returns:
        The stats object collected so far, or `None` if stats are not
        being tracked (in which case nothing is replaced).
    """
    snapshot = self._mm_cache_stats
    if snapshot is not None:
        # Swap in a fresh accumulator; the caller owns the old one
        self._mm_cache_stats = MultiModalCacheStats()

    return snapshot
|
|
|
|
def update_mm_cache_stats(self) -> None:
    """
    Fold the processor cache's latest delta into the tracked stats.

    Does nothing unless both a processor cache and a stats accumulator
    exist. Presence is tested with explicit `is not None` checks (as in
    `clear_mm_cache` and `shutdown`) so that an object which happens to
    be falsy, e.g. one defining `__len__` or `__bool__`, is not mistaken
    for a missing one.
    """
    mm_processor_cache = self.mm_processor_cache
    mm_cache_stats = self._mm_cache_stats

    if mm_processor_cache is None or mm_cache_stats is None:
        return

    delta = mm_processor_cache.make_stats(delta=True)
    mm_cache_stats.record(delta.total, delta.hits)
|
|
|
|
def clear_mm_cache(self) -> None:
    """
    Clear the multi-modal processor cache (if any) and flag the tracked
    stats (if any) with `reset = True`.
    """
    cache = self.mm_processor_cache
    if cache is not None:
        cache.clear_cache()

    stats = self._mm_cache_stats
    if stats is not None:
        stats.reset = True
|
|
|
|
def shutdown(self) -> None:
    """
    Close the multi-modal processor cache, if there is one.
    """
    cache = self.mm_processor_cache
    if cache is None:
        return

    cache.close()
|
|
|
|
def get_bos_token_id(self) -> int | None:
    """
    Return the tokenizer's BOS token id.

    When no tokenizer is set, a one-time warning is logged and `None`
    is returned instead.
    """
    tok = self.tokenizer
    if tok is not None:
        return tok.bos_token_id

    logger.warning_once(
        "Using None for BOS token id because tokenizer is not initialized"
    )
    return None
|
|
|
|
def get_eos_token_id(self) -> int | None:
    """
    Return the tokenizer's EOS token id.

    When no tokenizer is set, a one-time warning is logged and `None`
    is returned instead.
    """
    tok = self.tokenizer
    if tok is not None:
        return tok.eos_token_id

    logger.warning_once(
        "Using None for EOS token id because tokenizer is not initialized"
    )
    return None
|
|
|
|
def get_dec_start_token_id(self) -> int:
    """
    Obtain the decoder start token id employed by an encoder/decoder model.

    The value is read from `decoder_start_token_id` on the HF config; if
    that is absent or `None`, the BOS token id is used instead (with a
    one-time warning).

    Raises:
        RuntimeError: If neither the decoder start token id nor the BOS
            token id is available.
    """
    hf_config = self.model_config.hf_config
    token_id = getattr(hf_config, "decoder_start_token_id", None)
    if token_id is not None:
        return token_id

    logger.warning_once(
        "Falling back on <BOS> for decoder start token id "
        "because decoder start token id is not available."
    )
    token_id = self.get_bos_token_id()
    if token_id is None:
        raise RuntimeError("Cannot find decoder start token id or <BOS>")

    return token_id
|
|
|
|
@cached_property
def default_cmpl_tok_params(self) -> TokenizeParams:
    """
    Default tokenization parameters for completion-style prompts.

    With a multi-modal processor, its `info.default_tok_params` is used.
    Otherwise the parameters are built from the model config, with
    `add_special_tokens=True`.
    """
    if (mm_processor := self.mm_processor) is not None:
        return mm_processor.info.default_tok_params

    cfg = self.model_config
    lower_case = (cfg.encoder_config or {}).get("do_lower_case", False)

    return TokenizeParams(
        max_total_tokens=cfg.max_model_len,
        do_lower_case=lower_case,
        add_special_tokens=True,
    )
|
|
|
|
@cached_property
def default_chat_tok_params(self) -> TokenizeParams:
    """
    Default tokenization parameters for chat-style prompts.

    With a multi-modal processor, its `info.default_tok_params` is used.
    Otherwise the parameters are built from the model config, with
    `add_special_tokens=False`.
    """
    if (mm_processor := self.mm_processor) is not None:
        return mm_processor.info.default_tok_params

    cfg = self.model_config
    lower_case = (cfg.encoder_config or {}).get("do_lower_case", False)

    return TokenizeParams(
        max_total_tokens=cfg.max_model_len,
        do_lower_case=lower_case,
        add_special_tokens=False,
    )
|
|
|
|
# Step 1: Convert raw inputs to prompts
|
|
def render_prompt(
    self,
    prompt: DictPrompt | bytes,
) -> DictPrompt:
    """
    Convert one raw input into a dictionary prompt.

    `bytes` input is loaded via `safe_load_prompt_embeds` and wrapped in
    an `EmbedsPrompt`; anything else is returned unchanged.
    """
    if not isinstance(prompt, bytes):
        return prompt

    loaded = safe_load_prompt_embeds(self.model_config, prompt)
    return EmbedsPrompt(prompt_embeds=loaded)
|
|
|
|
def render_prompts(
    self,
    prompts: Sequence[DictPrompt | bytes],
) -> list[DictPrompt]:
    """
    Apply `render_prompt` to every input, preserving order.

    Raises:
        ValueError: If `prompts` is empty.
    """
    if len(prompts) == 0:
        raise ValueError("You must pass at least one prompt")

    return list(map(self.render_prompt, prompts))
|
|
|
|
async def render_prompts_async(
    self,
    prompts: Sequence[DictPrompt | bytes],
) -> list[DictPrompt]:
    """
    Async entry point for `render_prompts`.

    The work is done by a direct call to the synchronous
    `render_prompts`; nothing is awaited here.
    """
    dict_prompts = self.render_prompts(prompts)
    return dict_prompts
|
|
|
|
@abstractmethod
def render_messages(
    self,
    messages: list["ChatCompletionMessageParam"],
    params: ChatParams,
) -> tuple[list["ConversationMessage"], DictPrompt]:
    """
    Render one chat conversation into a dictionary prompt.

    Abstract: the base implementation only raises, so concrete renderers
    must override it.

    Args:
        messages: The chat messages of a single conversation.
        params: Chat rendering parameters.

    Returns:
        A tuple of the conversation messages and the resulting prompt
        (per the annotated return type).

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
|
|
|
|
async def render_messages_async(
    self,
    messages: list["ChatCompletionMessageParam"],
    params: ChatParams,
) -> tuple[list["ConversationMessage"], DictPrompt]:
    """
    Async entry point for `render_messages`.

    The default implementation calls the synchronous `render_messages`
    directly; nothing is awaited here.
    """
    rendered = self.render_messages(messages, params)
    return rendered
|
|
|
|
# Step 2: Tokenize prompts if necessary
|
|
def _tokenize_prompt(
    self,
    prompt: TextPrompt,
    params: TokenizeParams,
) -> TokensPrompt:
    """
    Encode the text of `prompt` into token IDs.

    The text under `prompt["prompt"]` is encoded with the keyword
    arguments from `params.get_encode_kwargs()`. All keys of the input
    prompt are carried over into the returned `TokensPrompt`.
    """
    encode_kwargs = params.get_encode_kwargs()
    token_ids = self.get_tokenizer().encode(prompt["prompt"], **encode_kwargs)

    return TokensPrompt(prompt_token_ids=token_ids, **prompt)
|
|
|
|
async def _tokenize_prompt_async(
    self,
    prompt: TextPrompt,
    params: TokenizeParams,
) -> TokensPrompt:
    """
    Async counterpart of `_tokenize_prompt`.

    Encodes `prompt["prompt"]` through the async tokenizer using
    `params.get_encode_kwargs()`, and carries all keys of the input
    prompt over into the returned `TokensPrompt`.
    """
    encode_kwargs = params.get_encode_kwargs()
    async_tok = self.get_async_tokenizer()
    token_ids = await async_tok.encode(prompt["prompt"], **encode_kwargs)

    return TokensPrompt(prompt_token_ids=token_ids, **prompt)
|
|
|
|
def _detokenize_prompt(self, prompt: TokensPrompt) -> TokensPrompt:
|
|
tokenizer = self.get_tokenizer()
|
|
prompt["prompt"] = tokenizer.decode(prompt["prompt_token_ids"])
|
|
|
|
return prompt
|
|
|
|
async def _detokenize_prompt_async(self, prompt: TokensPrompt) -> TokensPrompt:
|
|
tokenizer = self.get_async_tokenizer()
|
|
prompt["prompt"] = await tokenizer.decode(prompt["prompt_token_ids"])
|
|
|
|
return prompt
|
|
|
|
@overload
|
|
def _tokenize_singleton_prompt(
|
|
self,
|
|
prompt: TextPrompt | TokensPrompt,
|
|
params: TokenizeParams,
|
|
) -> TokensPrompt: ...
|
|
|
|
@overload
|
|
def _tokenize_singleton_prompt( # type: ignore[misc]
|
|
self,
|
|
prompt: EmbedsPrompt,
|
|
params: TokenizeParams,
|
|
) -> EmbedsPrompt: ...
|
|
|
|
def _tokenize_singleton_prompt(
|
|
self,
|
|
prompt: SingletonDictPrompt,
|
|
params: TokenizeParams,
|
|
) -> SingletonTokPrompt:
|
|
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
|
|
prompt = params.apply_pre_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
|
|
prompt = self._tokenize_prompt(prompt, params)
|
|
|
|
if params.needs_detokenization and "prompt" not in prompt:
|
|
if "prompt_token_ids" not in prompt:
|
|
raise RuntimeError("Cannot run detokenization on embeddings")
|
|
|
|
prompt = self._detokenize_prompt(prompt) # type: ignore[arg-type]
|
|
|
|
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
|
|
|
|
@overload
|
|
async def _tokenize_singleton_prompt_async(
|
|
self,
|
|
prompt: TextPrompt | TokensPrompt,
|
|
params: TokenizeParams,
|
|
) -> TokensPrompt: ...
|
|
|
|
@overload
|
|
async def _tokenize_singleton_prompt_async( # type: ignore[misc]
|
|
self,
|
|
prompt: EmbedsPrompt,
|
|
params: TokenizeParams,
|
|
) -> EmbedsPrompt: ...
|
|
|
|
async def _tokenize_singleton_prompt_async(
|
|
self,
|
|
prompt: SingletonDictPrompt,
|
|
params: TokenizeParams,
|
|
) -> SingletonTokPrompt:
|
|
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
|
|
prompt = params.apply_pre_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
|
|
prompt = await self._tokenize_prompt_async(prompt, params)
|
|
|
|
if params.needs_detokenization and "prompt" not in prompt:
|
|
if "prompt_token_ids" not in prompt:
|
|
raise RuntimeError("Cannot run detokenization on embeddings")
|
|
|
|
prompt = await self._detokenize_prompt_async(prompt) # type: ignore[arg-type]
|
|
|
|
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
|
|
|
|
def _tokenize_enc_dec_prompt(
    self,
    prompt: EncoderDecoderDictPrompt,
    params: TokenizeParams,
) -> EncoderDecoderTokPrompt:
    """
    Tokenize both halves of an encoder/decoder prompt.

    The encoder prompt is always tokenized; the decoder prompt is
    tokenized only when it is not `None`, and is passed through as
    `None` otherwise.
    """
    enc_prompt = self._tokenize_singleton_prompt(prompt["encoder_prompt"], params)

    raw_dec = prompt["decoder_prompt"]
    if raw_dec is None:
        dec_prompt = None
    else:
        dec_prompt = self._tokenize_singleton_prompt(raw_dec, params)

    return EncoderDecoderTokPrompt(
        encoder_prompt=enc_prompt,
        decoder_prompt=dec_prompt,
    )
|
|
|
|
async def _tokenize_enc_dec_prompt_async(
    self,
    prompt: EncoderDecoderDictPrompt,
    params: TokenizeParams,
) -> EncoderDecoderTokPrompt:
    """
    Async counterpart of `_tokenize_enc_dec_prompt`.

    Encoder and decoder prompts are tokenized concurrently through
    `asyncio.gather`. A `None` decoder prompt stays `None`.
    """
    enc_job = self._tokenize_singleton_prompt_async(prompt["encoder_prompt"], params)

    raw_dec = prompt["decoder_prompt"]
    if raw_dec is None:
        # `asyncio.sleep(0)` resolves to None, so gather() still receives
        # two awaitables and yields None for the decoder slot
        dec_job = asyncio.sleep(0)
    else:
        dec_job = self._tokenize_singleton_prompt_async(raw_dec, params)

    enc_prompt, dec_prompt = await asyncio.gather(enc_job, dec_job)

    return EncoderDecoderTokPrompt(
        encoder_prompt=enc_prompt,
        decoder_prompt=dec_prompt,
    )
|
|
|
|
def tokenize_prompt(
    self,
    prompt: DictPrompt,
    params: TokenizeParams,
) -> TokPrompt:
    """
    Tokenize one prompt, dispatching on its kind.

    A prompt containing the key `"encoder_prompt"` is handled as an
    encoder/decoder prompt; every other prompt is handled as a singleton.
    """
    if "encoder_prompt" not in prompt:
        return self._tokenize_singleton_prompt(prompt, params)

    return self._tokenize_enc_dec_prompt(prompt, params)  # type: ignore[arg-type]
|
|
|
|
def tokenize_prompts(
    self,
    prompts: Sequence[DictPrompt],
    params: TokenizeParams,
) -> list[TokPrompt]:
    """
    Tokenize every prompt with the same `params`, preserving order.
    """
    tokenized: list[TokPrompt] = []
    for dict_prompt in prompts:
        tokenized.append(self.tokenize_prompt(dict_prompt, params))

    return tokenized
|
|
|
|
async def tokenize_prompt_async(
    self,
    prompt: DictPrompt,
    params: TokenizeParams,
) -> TokPrompt:
    """
    Async counterpart of `tokenize_prompt`.

    A prompt containing the key `"encoder_prompt"` is handled as an
    encoder/decoder prompt; every other prompt is handled as a singleton.
    """
    if "encoder_prompt" not in prompt:
        return await self._tokenize_singleton_prompt_async(prompt, params)

    return await self._tokenize_enc_dec_prompt_async(prompt, params)  # type: ignore[arg-type]
|
|
|
|
async def tokenize_prompts_async(
    self,
    prompts: Sequence[DictPrompt],
    params: TokenizeParams,
) -> list[TokPrompt]:
    """
    Tokenize all prompts concurrently via `asyncio.gather`.

    Results are returned in the same order as `prompts`.
    """
    pending = [self.tokenize_prompt_async(item, params) for item in prompts]
    return await asyncio.gather(*pending)
|
|
|
|
# Step 3: Add extra keys to the prompts
|
|
def _apply_prompt_extras(
|
|
self,
|
|
prompts: Sequence[TokPrompt],
|
|
prompt_extras: dict[str, Any] | None,
|
|
):
|
|
if not prompt_extras:
|
|
return
|
|
|
|
for prompt in prompts:
|
|
target_prompt = extract_target_prompt(self.model_config, prompt)
|
|
target_prompt.update(prompt_extras) # type: ignore[arg-type]
|
|
|
|
# Step 4: Convert to engine inputs
|
|
def _validate_mm_uuids(
|
|
self,
|
|
mm_data: "MultiModalDataDict",
|
|
mm_data_items: "MultiModalDataItems",
|
|
mm_uuid_items: "MultiModalUUIDItems",
|
|
) -> None:
|
|
# NOTE: Keys corresponding to `None` in `mm_data` don't appear in
|
|
# `mm_data_items`
|
|
modalities = mm_data.keys() | mm_uuid_items.keys()
|
|
|
|
for modality in modalities:
|
|
data_items = mm_data_items.get(modality)
|
|
uuid_items = mm_uuid_items.get(modality)
|
|
|
|
if data_items is None:
|
|
if uuid_items is None:
|
|
raise ValueError(
|
|
f"multi_modal_data[{modality!r}] is empty but "
|
|
f"multi_modal_uuids[{modality!r}] is missing."
|
|
)
|
|
|
|
elif uuid_items is not None:
|
|
if len(data_items) != len(uuid_items):
|
|
raise ValueError(
|
|
f"If given, multi_modal_uuids[{modality!r}] must have "
|
|
f"same length as multi_modal_data[{modality!r}], but "
|
|
f"got {len(uuid_items)} vs {len(data_items)}."
|
|
)
|
|
|
|
for i, item in enumerate(data_items):
|
|
if item is None and uuid_items[i] is None:
|
|
raise ValueError(
|
|
f"multi_modal_data[{modality!r}][{i}] is empty but "
|
|
f"multi_modal_uuids[{modality!r}][{i}] is missing."
|
|
)
|
|
|
|
def _process_mm_uuids(
|
|
self,
|
|
mm_data: "MultiModalDataDict",
|
|
mm_data_items: "MultiModalDataItems",
|
|
mm_uuid_items: "MultiModalUUIDItems",
|
|
mm_req_id: str,
|
|
):
|
|
model_config = self.model_config
|
|
|
|
# NOTE: When users explicitly turn off BOTH prefix caching and input
|
|
# processing caching, no multimodal features or embeddings will be
|
|
# reused across requests, therefore identifying multimodal data items
|
|
# by their content is no longer necessary, and we create uuids with
|
|
# `<mm_req_id>-<modality>-<index>`, overriding even user-provided ones.
|
|
if (
|
|
model_config.multimodal_config
|
|
and model_config.multimodal_config.mm_processor_cache_gb == 0
|
|
and not self.config.cache_config.enable_prefix_caching
|
|
):
|
|
mm_uuid_items = {
|
|
modality: [f"{mm_req_id}-{modality}-{i}" for i in range(data_count)]
|
|
for modality, data_count in mm_data_items.get_all_counts().items()
|
|
}
|
|
|
|
self._validate_mm_uuids(mm_data, mm_data_items, mm_uuid_items)
|
|
|
|
return mm_uuid_items
|
|
|
|
# TODO: Remove str and tokenization_kwargs after deprecating InputPreprocessor
|
|
def _process_multimodal(
|
|
self,
|
|
prompt: list[int] | str,
|
|
mm_data: "MultiModalDataDict",
|
|
mm_uuids: "MultiModalUUIDDict | None",
|
|
mm_processor_kwargs: Mapping[str, object] | None,
|
|
tokenization_kwargs: dict[str, Any] | None,
|
|
) -> "MultiModalInputs":
|
|
from vllm.multimodal.parse import parse_mm_uuids
|
|
from vllm.multimodal.processing import ProcessorInputs as MMProcessorInputs
|
|
|
|
mm_req_id = f"renderer{self.api_process_rank}-mm-{self._mm_req_counter.inc(1)}"
|
|
|
|
mm_processor = self.get_mm_processor()
|
|
|
|
mm_data_items = mm_processor.info.parse_mm_data(mm_data)
|
|
mm_uuid_items = parse_mm_uuids(mm_uuids)
|
|
|
|
mm_uuid_items = self._process_mm_uuids(
|
|
mm_data, mm_data_items, mm_uuid_items, mm_req_id
|
|
)
|
|
|
|
mm_processor_inputs = MMProcessorInputs(
|
|
prompt,
|
|
mm_data_items,
|
|
mm_uuid_items,
|
|
hf_processor_mm_kwargs=mm_processor_kwargs or {},
|
|
tokenization_kwargs=tokenization_kwargs or {},
|
|
)
|
|
mm_timing_ctx = self._mm_timing_registry.get(mm_req_id)
|
|
|
|
with set_default_torch_num_threads():
|
|
mm_inputs = mm_processor.apply(mm_processor_inputs, mm_timing_ctx)
|
|
|
|
self.update_mm_cache_stats()
|
|
|
|
return mm_inputs
|
|
|
|
def _process_tokens(
|
|
self,
|
|
prompt: TokensPrompt,
|
|
) -> "TokenInputs | MultiModalInputs":
|
|
prompt_token_ids = prompt["prompt_token_ids"]
|
|
|
|
inputs: TokenInputs | MultiModalInputs
|
|
if multi_modal_data := prompt.get("multi_modal_data"):
|
|
inputs = self._process_multimodal(
|
|
prompt_token_ids,
|
|
multi_modal_data,
|
|
mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
|
|
tokenization_kwargs=None, # Tokenization already done in Step 2
|
|
mm_uuids=prompt.get("multi_modal_uuids"),
|
|
)
|
|
else:
|
|
inputs = token_inputs(prompt_token_ids)
|
|
|
|
if prompt_text := prompt.get("prompt"):
|
|
inputs["prompt"] = prompt_text
|
|
if cache_salt := prompt.get("cache_salt"):
|
|
inputs["cache_salt"] = cache_salt
|
|
|
|
return inputs
|
|
|
|
def _process_embeds(
|
|
self,
|
|
prompt: EmbedsPrompt,
|
|
) -> EmbedsInputs:
|
|
if not self.model_config.enable_prompt_embeds:
|
|
raise ValueError(
|
|
"You must set `--enable-prompt-embeds` to input `prompt_embeds`."
|
|
)
|
|
|
|
prompt_embeds = prompt["prompt_embeds"]
|
|
|
|
# prompt_embeds must be (seq_len, hidden_size), but if the user
|
|
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
|
|
# we can unambiguously process the intent by squeezing the batch
|
|
# dimension.
|
|
if prompt_embeds.ndim == 3:
|
|
prompt_embeds = prompt_embeds.squeeze(dim=0)
|
|
|
|
if prompt_embeds.ndim != 2:
|
|
raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")
|
|
|
|
# Tensors must be on CPU for serialization between processes
|
|
# in the MsgpackEncoder. Casting to CPU here ensures that there is no
|
|
# hidden device transfer in the critical path of generation.
|
|
prompt_embeds = prompt_embeds.cpu()
|
|
|
|
return embeds_inputs(
|
|
prompt_embeds=prompt_embeds,
|
|
cache_salt=prompt.get("cache_salt"),
|
|
)
|
|
|
|
def _process_singleton(
|
|
self,
|
|
prompt: SingletonTokPrompt,
|
|
) -> SingletonInputs:
|
|
if "prompt_embeds" in prompt:
|
|
return self._process_embeds(prompt) # type: ignore[arg-type]
|
|
|
|
return self._process_tokens(prompt) # type: ignore[arg-type]
|
|
|
|
def _process_enc_dec(
    self,
    prompt: EncoderDecoderTokPrompt,
) -> EncoderDecoderInputs:
    """
    Convert a tokenized encoder/decoder prompt into engine inputs.

    Both halves go through `_process_singleton` (a `None` decoder prompt
    stays `None`), and are combined by `build_enc_dec_inputs` together
    with the decoder start token id.
    """
    enc_prompt = prompt["encoder_prompt"]
    dec_prompt = prompt["decoder_prompt"]

    encoder_inputs = self._process_singleton(enc_prompt)
    if dec_prompt is None:
        decoder_inputs = None
    else:
        decoder_inputs = self._process_singleton(dec_prompt)

    return build_enc_dec_inputs(
        encoder_inputs=encoder_inputs,
        decoder_inputs=decoder_inputs,
        decoder_start_token_id=self.get_dec_start_token_id(),
    )
|
|
|
|
def process_for_engine(
    self, prompt: TokPrompt, arrival_time: float
) -> ProcessorInputs:
    """
    Convert a tokenized prompt into engine inputs and stamp the result
    with `arrival_time`.

    A prompt containing `"encoder_prompt"` is processed as an
    encoder/decoder prompt; every other prompt as a singleton.
    """
    engine_prompt: ProcessorInputs
    if "encoder_prompt" not in prompt:
        engine_prompt = self._process_singleton(prompt)
    else:
        engine_prompt = self._process_enc_dec(prompt)  # type: ignore[arg-type]

    engine_prompt["arrival_time"] = arrival_time

    return engine_prompt
|
|
|
|
# Top-level methods
|
|
def render_cmpl(
    self,
    prompts: Sequence[DictPrompt | bytes],
    tok_params: TokenizeParams | None = None,
    *,
    prompt_extras: dict[str, Any] | None = None,
):
    """
    Run the full rendering pipeline for completion-style prompts.

    The prompts are rendered, tokenized, extended with `prompt_extras`,
    and converted to engine inputs. All outputs share one arrival time,
    taken when this method is entered.

    Args:
        prompts: Raw prompts (dictionary prompts or serialized embeddings).
        tok_params: Tokenization parameters; defaults to
            `self.default_cmpl_tok_params`.
        prompt_extras: Extra keys to merge into each tokenized prompt.

    Returns:
        A list with one engine input per prompt.
    """
    arrival_time = time.time()

    params = self.default_cmpl_tok_params if tok_params is None else tok_params

    tok_prompts = self.tokenize_prompts(self.render_prompts(prompts), params)
    self._apply_prompt_extras(tok_prompts, prompt_extras)

    eng_prompts = []
    for tok_prompt in tok_prompts:
        eng_prompts.append(self.process_for_engine(tok_prompt, arrival_time))

    return eng_prompts
|
|
|
|
async def render_cmpl_async(
    self,
    prompts: Sequence[DictPrompt | bytes],
    tok_params: TokenizeParams | None = None,
    *,
    prompt_extras: dict[str, Any] | None = None,
):
    """
    Async counterpart of `render_cmpl`.

    Rendering and tokenization use the async methods; applying
    `prompt_extras` and converting to engine inputs are synchronous. All
    outputs share one arrival time, taken when this method is entered.

    Args:
        prompts: Raw prompts (dictionary prompts or serialized embeddings).
        tok_params: Tokenization parameters; defaults to
            `self.default_cmpl_tok_params`.
        prompt_extras: Extra keys to merge into each tokenized prompt.

    Returns:
        A list with one engine input per prompt.
    """
    arrival_time = time.time()

    params = self.default_cmpl_tok_params if tok_params is None else tok_params

    dict_prompts = await self.render_prompts_async(prompts)
    tok_prompts = await self.tokenize_prompts_async(dict_prompts, params)
    self._apply_prompt_extras(tok_prompts, prompt_extras)

    eng_prompts = []
    for tok_prompt in tok_prompts:
        eng_prompts.append(self.process_for_engine(tok_prompt, arrival_time))

    return eng_prompts
|
|
|
|
def render_chat(
    self,
    conversations: Sequence[list["ChatCompletionMessageParam"]],
    chat_params: ChatParams,
    tok_params: TokenizeParams | None = None,
    *,
    prompt_extras: dict[str, Any] | None = None,
):
    """
    Run the full rendering pipeline for chat conversations.

    Each conversation is rendered with `render_messages`; the resulting
    prompts are tokenized, extended with `prompt_extras`, and converted
    to engine inputs. All outputs share one arrival time, taken when this
    method is entered.

    Args:
        conversations: One message list per conversation.
        chat_params: Parameters passed to `render_messages`.
        tok_params: Tokenization parameters; defaults to
            `self.default_chat_tok_params`.
        prompt_extras: Extra keys to merge into each tokenized prompt.

    Returns:
        A tuple `(conversations, engine_prompts)` of two lists, each with
        one entry per input conversation.
    """
    arrival_time = time.time()

    params = self.default_chat_tok_params if tok_params is None else tok_params

    rendered = [self.render_messages(conv, chat_params) for conv in conversations]

    # Split the (conversation, prompt) pairs into two parallel lists
    out_conversations = [conv for conv, _ in rendered]
    dict_prompts = [prompt for _, prompt in rendered]

    tok_prompts = self.tokenize_prompts(dict_prompts, params)
    self._apply_prompt_extras(tok_prompts, prompt_extras)

    eng_prompts = []
    for tok_prompt in tok_prompts:
        eng_prompts.append(self.process_for_engine(tok_prompt, arrival_time))

    return out_conversations, eng_prompts
|
|
|
|
async def render_chat_async(
    self,
    conversations: Sequence[list["ChatCompletionMessageParam"]],
    chat_params: ChatParams,
    tok_params: TokenizeParams | None = None,
    *,
    prompt_extras: dict[str, Any] | None = None,
):
    """
    Async counterpart of `render_chat`.

    All conversations are rendered concurrently with `asyncio.gather`,
    then tokenized through the async tokenization path. Applying
    `prompt_extras` and converting to engine inputs are synchronous. All
    outputs share one arrival time, taken when this method is entered.

    Args:
        conversations: One message list per conversation.
        chat_params: Parameters passed to `render_messages_async`.
        tok_params: Tokenization parameters; defaults to
            `self.default_chat_tok_params`.
        prompt_extras: Extra keys to merge into each tokenized prompt.

    Returns:
        A tuple `(conversations, engine_prompts)` of two lists, each with
        one entry per input conversation.
    """
    arrival_time = time.time()

    params = self.default_chat_tok_params if tok_params is None else tok_params

    pending = [
        self.render_messages_async(conv, chat_params) for conv in conversations
    ]
    rendered = await asyncio.gather(*pending)

    # Split the (conversation, prompt) pairs into two parallel lists
    out_conversations = [conv for conv, _ in rendered]
    dict_prompts = [prompt for _, prompt in rendered]

    tok_prompts = await self.tokenize_prompts_async(dict_prompts, params)
    self._apply_prompt_extras(tok_prompts, prompt_extras)

    eng_prompts = []
    for tok_prompt in tok_prompts:
        eng_prompts.append(self.process_for_engine(tok_prompt, arrival_time))

    return out_conversations, eng_prompts
|