# SPDX-License-Identifier: Apache-2.0

import time
from collections.abc import Mapping
from typing import Optional, Union

import vllm.platforms
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
                         PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
                             MultiModalRegistry)
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.structured_output.utils import validate_structured_output_request


class Processor:
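    """Validates incoming requests and converts them into EngineCoreRequests.

    The processor checks sampling/LoRA/structured-output parameters,
    tokenizes the prompt (including multi-modal preprocessing), and packages
    the result for the V1 engine core.

    Example (an illustrative sketch only; the engine normally constructs and
    calls this internally, and the argument values here are hypothetical)::

        processor = Processor(vllm_config, tokenizer)
        engine_request = processor.process_inputs(
            request_id="req-0",
            prompt="Hello, world!",
            params=SamplingParams(max_tokens=16),
        )
    """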

    def __init__(
        self,
        vllm_config: VllmConfig,
        tokenizer: BaseTokenizerGroup,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.decoding_config = vllm_config.decoding_config
        self.tokenizer = tokenizer

        self.generation_config_fields = (
            self.model_config.try_get_generation_config())
        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer,
                                                    mm_registry)

        # Multi-modal hasher (for images)
        self.use_hash = (
            not self.model_config.disable_mm_preprocessor_cache) or \
            self.cache_config.enable_prefix_caching

    def _validate_logprobs(
        self,
        params: SamplingParams,
    ) -> None:
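        """Reject requests whose sample or prompt logprobs exceed the
        model's configured ``max_logprobs`` limit."""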
        max_logprobs = self.model_config.max_logprobs
        # Validate sample logprobs.
        if params.logprobs and params.logprobs > max_logprobs:
            raise ValueError(
                f"Requested sample logprobs of {params.logprobs}, "
                f"which is greater than max allowed: {max_logprobs}")

        # Validate prompt logprobs.
        if params.prompt_logprobs and params.prompt_logprobs > max_logprobs:
            raise ValueError(
                f"Requested prompt logprobs of {params.prompt_logprobs}, "
                f"which is greater than max allowed: {max_logprobs}")

    def _validate_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
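        """Validate structured-output settings and ``allowed_token_ids``."""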
        self._validate_structured_output(params)

        if params.allowed_token_ids is None:
            return
        if not params.allowed_token_ids:
            raise ValueError("allowed_token_ids must not be empty "
                             "if provided!")
        vocab_size = self.model_config.get_vocab_size()
        if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
            raise ValueError(
                "allowed_token_ids contains out-of-vocab token id!")

    def _validate_supported_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
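        """Reject sampling features not yet supported by vLLM V1
        (``best_of`` and per-request logits processors)."""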
        # best_of is not yet supported.
        if params.best_of is not None and params.best_of > 1:
            raise ValueError("vLLM V1 does not yet support best_of.")
        # Logits processors are not supported.
        if params.logits_processors:
            raise ValueError("vLLM V1 does not support per-request "
                             "user-provided logits processors.")

    def _validate_params(
        self,
        params: Union[SamplingParams, PoolingParams],
    ):
        """
        Validate supported SamplingParams.

        Should raise ValueError if unsupported for the API server.
        """

        if not isinstance(params, SamplingParams):
            raise ValueError("V1 does not yet support Pooling models.")

        self._validate_logprobs(params)
        self._validate_sampling_params(params)
        self._validate_supported_sampling_params(params)

    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
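        """Reject LoRA requests when LoRA support is not enabled."""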
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

    def _validate_structured_output(self, params: SamplingParams) -> None:
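        """Ensure guided decoding requests use the xgrammar backend and run
        on a supported configuration (no speculative decoding, not on TPU)."""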
        if not params.guided_decoding or not self.decoding_config:
            return
        if self.decoding_config.guided_decoding_backend != "xgrammar":
            raise ValueError(
                "Only xgrammar structured output is supported in V1.")
        if (params.guided_decoding.backend
                and params.guided_decoding.backend != 'xgrammar'):
            raise ValueError(
                "Only xgrammar structured output is supported in V1.")
        if self.vllm_config.speculative_config:
            raise ValueError("Structured output is not supported with "
                             "speculative decoding.")
        if vllm.platforms.current_platform.is_tpu():
            raise ValueError("Structured output is not supported on TPU.")

        validate_structured_output_request(params)

    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> EngineCoreRequest:
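        """Validate the request, preprocess the prompt (tokenization and
        multi-modal handling), and assemble an EngineCoreRequest."""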

        # TODO(woosuk): Support pooling models.
        # TODO(woosuk): Support encoder-decoder models.

        self._validate_lora(lora_request)
        self._validate_params(params)
        if priority != 0:
            raise ValueError("V1 does not support priority yet.")
        if trace_headers is not None:
            raise ValueError("V1 does not support tracing yet.")
        if prompt_adapter_request is not None:
            raise ValueError("V1 does not support prompt_adapter_request.")
        if arrival_time is None:
            arrival_time = time.time()

        # Process inputs, which includes:
        # 1. Tokenize text prompt, with LoRA request if one exists.
        # 2. For multimodal models with a merged preprocessor, preprocess
        #    multimodal data and expand prompt token ids accordingly.
        # 3. Apply prompt adapter to prompt token ids if one exists.
        processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            return_mm_hashes=self.use_hash,
        )
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        self._validate_model_inputs(processed_inputs, lora_request)

        if is_encoder_decoder_inputs(processed_inputs):
            decoder_inputs = SingletonInputsAdapter(
                processed_inputs["decoder"])
            encoder_inputs = SingletonInputsAdapter(
                processed_inputs["encoder"])
        else:
            decoder_inputs = SingletonInputsAdapter(processed_inputs)
            encoder_inputs = None

        # TODO: Impl encoder-decoder
        if encoder_inputs is not None:
            raise NotImplementedError

        assert isinstance(params, SamplingParams)
        # TODO: can we avoid cloning here in multiproc case?
        sampling_params = params.clone()
        # If max_tokens is unset, generate up to the model's max_model_len.
        if sampling_params.max_tokens is None:
            sampling_params.max_tokens = (self.model_config.max_model_len -
                                          len(decoder_inputs.prompt_token_ids))
        sampling_params.update_from_generation_config(
            self.generation_config_fields, eos_token_id)
        sampling_params.update_from_tokenizer(
            self.tokenizer.get_lora_tokenizer(lora_request))

        # Multimodal related.
        sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None
        sorted_mm_positions: Optional[list[PlaceholderRange]] = None
        sorted_mm_hashes: Optional[list[str]] = None
        if (decoder_mm_inputs := decoder_inputs.multi_modal_data):
            assert isinstance(decoder_mm_inputs, MultiModalKwargs)

            # The output of merged multi-modal processor (`decoder_mm_inputs`)
            # contains the kwargs for all items from all modalities.
            # This code separates them so that there is one set of kwargs
            # per item per modality.
            individual_mm_inputs = [
                MultiModalKwargs.from_items([item])
                for modality in decoder_mm_inputs.modalities
                for item in decoder_mm_inputs.get_items(modality)
            ]

            # Merge and flatten multimodal placeholders, hashes and inputs
            # from dictionaries to lists, and sort them by each item's
            # position in the input sequence.
            # NOTE: interleaved modalities are not supported.
            (
                sorted_modalities,
                sorted_mm_positions,
                sorted_mm_hashes,
            ) = merge_and_sort_multimodal_metadata(
                decoder_inputs.multi_modal_placeholders,
                decoder_inputs.multi_modal_hashes if self.use_hash else None,
            )

            # NOTE: Sort multimodal inputs/kwargs ONLY IF there are multiple
            # modalities involved.
            if len(sorted_modalities) > 1:
                modality_order_dict = {
                    modality: order
                    for order, modality in enumerate(sorted_modalities)
                }

                # Sanity check to make sure each multimodal input has only one
                # modality key.
                for mm_input in individual_mm_inputs:
                    assert len(mm_input.modalities) == 1

                # Sort MultiModalKwargs to match sorted_mm_positions.
                sorted_mm_inputs = sorted(
                    individual_mm_inputs,
                    key=lambda mm_input: modality_order_dict[list(
                        mm_input.modalities)[0]])
            else:
                sorted_mm_inputs = individual_mm_inputs

        return EngineCoreRequest(
            request_id=request_id,
            prompt=decoder_inputs.prompt,
            prompt_token_ids=decoder_inputs.prompt_token_ids,
            mm_inputs=sorted_mm_inputs,
            mm_hashes=sorted_mm_hashes,
            mm_placeholders=sorted_mm_positions,
            sampling_params=sampling_params,
            eos_token_id=eos_token_id,
            arrival_time=arrival_time,
            lora_request=lora_request,
        )

    def _validate_model_inputs(self,
                               inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest] = None):
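        """Check that the processed prompt is non-empty, within the
        tokenizer vocabulary, and fits inside ``max_model_len``."""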
        if is_encoder_decoder_inputs(inputs):
            # For encoder-decoder multimodal models, the max_prompt_len
            # restricts the decoder prompt length
            prompt_inputs = inputs["decoder" if self.model_config.
                                   is_multimodal_model else "encoder"]
        else:
            prompt_inputs = inputs

        prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids

        if prompt_ids is None or len(prompt_ids) == 0:
            raise ValueError("Prompt cannot be empty")

        max_input_id = max(prompt_ids)
        max_allowed = self.tokenizer.get_lora_tokenizer(
            lora_request).max_token_id
        if max_input_id > max_allowed:
            raise ValueError(
                f"Token id {max_input_id} is out of vocabulary")

        if len(prompt_ids) >= self.model_config.max_model_len:
            raise ValueError(
                f"Prompt length of {len(prompt_ids)} is longer than the "
                f"maximum model length of {self.model_config.max_model_len}.")

        if self.model_config.is_multimodal_model:
            max_prompt_len = self.model_config.max_model_len

            if len(prompt_ids) > max_prompt_len:
                raise ValueError(
                    f"The prompt (total length {len(prompt_ids)}) is too long "
                    f"to fit into the model (context length {max_prompt_len}). "
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well.")

        # TODO: Find out how many placeholder tokens are there so we can
        # check that chunked prefill does not truncate them
        # max_batch_len = self.scheduler_config.max_num_batched_tokens