diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index f54d9121c..9474c543e 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -500,7 +500,7 @@ class LLM:
             engine_prompts: Sequence[DictPrompt | TokPrompt] = [
                 engine_prompt
                 for prompt, param in zip(seq_prompts, seq_params)
-                for engine_prompt in self._preprocess_completion(
+                for engine_prompt in self._preprocess_cmpl(
                     [prompt],
                     tokenization_kwargs=merge_kwargs(
                         tokenization_kwargs,
@@ -509,7 +509,7 @@ class LLM:
                 )
             ]
         else:
-            engine_prompts = self._preprocess_completion(
+            engine_prompts = self._preprocess_cmpl(
                 seq_prompts,
                 tokenization_kwargs=tokenization_kwargs,
             )
@@ -889,7 +889,7 @@ class LLM:
             add_special_tokens=not model_config.is_encoder_decoder,
         ).with_kwargs(tokenization_kwargs)

-    def _preprocess_completion(
+    def _preprocess_cmpl(
         self,
         prompts: Sequence[PromptType],
         tokenization_kwargs: dict[str, Any] | None = None,
@@ -901,7 +901,7 @@ class LLM:
         Refer to [LLM.generate][] for a complete description of the arguments.

         Returns:
-            A list of `TokensPrompts` objects containing the tokenized prompt
+            A list of `TokPrompt` objects containing the tokenized prompt
             after chat template interpolation, and the raw multi-modal inputs.
         """
         renderer = self.renderer
@@ -943,7 +943,7 @@ class LLM:
         Refer to [LLM.chat][] for a complete description of the arguments.

         Returns:
-            A list of `TokensPrompts` objects containing the tokenized prompt
+            A list of `TokPrompt` objects containing the tokenized prompt
             after chat template interpolation, and the raw multi-modal inputs.
         """
         renderer = self.renderer
@@ -1823,11 +1823,11 @@ class LLM:
         if any(param.truncate_prompt_tokens is not None for param in seq_params):
             # TODO: Remove this after deprecating `param.truncate_prompt_tokens`
             # Then, move the code from the `else` block to the top and let
-            # `self._preprocess_completion` handle prompt normalization
+            # `self._preprocess_cmpl` handle prompt normalization
             engine_prompts: Sequence[DictPrompt | TokPrompt] = [
                 engine_prompt
                 for prompt, param in zip(seq_prompts, seq_params)
-                for engine_prompt in self._preprocess_completion(
+                for engine_prompt in self._preprocess_cmpl(
                     [prompt],
                     tokenization_kwargs=merge_kwargs(
                         tokenization_kwargs,
@@ -1836,7 +1836,7 @@ class LLM:
                 )
             ]
         else:
-            engine_prompts = self._preprocess_completion(
+            engine_prompts = self._preprocess_cmpl(
                 seq_prompts,
                 tokenization_kwargs=tokenization_kwargs,
             )
diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py
index d39decaa7..1484fca5b 100644
--- a/vllm/entrypoints/openai/engine/serving.py
+++ b/vllm/entrypoints/openai/engine/serving.py
@@ -5,7 +5,7 @@ import json
 import sys
 import time
 import traceback
-from collections.abc import AsyncGenerator, Callable, Mapping
+from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
 from dataclasses import dataclass, field
 from http import HTTPStatus
 from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
@@ -959,15 +959,22 @@ class OpenAIServing:
         prompt_input: str | list[str] | list[int] | list[list[int]] | None,
         prompt_embeds: bytes | list[bytes] | None,
     ) -> list[TokPrompt]:
-        renderer = self.renderer
-        model_config = self.model_config
-
         prompts = list[SingletonPrompt | bytes]()
         if prompt_embeds is not None:
             # embeds take higher priority
             prompts.extend(prompt_to_seq(prompt_embeds))
         if prompt_input is not None:
             prompts.extend(prompt_to_seq(prompt_input))
+        return await self._preprocess_cmpl(request, prompts)
+
+    async def _preprocess_cmpl(
+        self,
+        request: RendererRequest,
+        prompts: Sequence[PromptType | bytes],
+    ) -> list[TokPrompt]:
+        renderer = self.renderer
+        model_config = self.model_config
+
         parsed_prompts = [
             (
                 prompt
diff --git a/vllm/entrypoints/pooling/pooling/protocol.py b/vllm/entrypoints/pooling/pooling/protocol.py
index 6a5a743cd..a8c1c59ff 100644
--- a/vllm/entrypoints/pooling/pooling/protocol.py
+++ b/vllm/entrypoints/pooling/pooling/protocol.py
@@ -100,6 +100,18 @@ class IOProcessorRequest(PoolingBasicRequestMixin, EncodingRequestMixin, Generic
     data: T
     task: PoolingTask = "plugin"

+    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
+        encoder_config = model_config.encoder_config or {}
+
+        return TokenizeParams(
+            max_total_tokens=model_config.max_model_len,
+            max_output_tokens=0,
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            do_lower_case=encoder_config.get("do_lower_case", False),
+            add_special_tokens=not model_config.is_encoder_decoder,
+            max_total_tokens_param="max_model_len",
+        )
+

 class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
     request_id: str | None = None
diff --git a/vllm/entrypoints/pooling/pooling/serving.py b/vllm/entrypoints/pooling/pooling/serving.py
index 5c5d649f6..16a9722c0 100644
--- a/vllm/entrypoints/pooling/pooling/serving.py
+++ b/vllm/entrypoints/pooling/pooling/serving.py
@@ -6,7 +6,7 @@ import json
 import time
 from collections.abc import AsyncGenerator, Callable, Sequence
 from functools import partial
-from typing import Any, Final, Literal, cast
+from typing import Final, Literal, cast

 import jinja2
 from fastapi import Request
@@ -108,7 +108,10 @@ class OpenAIServingPooling(OpenAIServing):
             raw_prompts = await self.io_processor.pre_process_async(
                 prompt=validated_prompt, request_id=request_id
             )
-            engine_prompts = prompt_to_seq(raw_prompts)
+            engine_prompts = await self._preprocess_cmpl(
+                request,
+                prompt_to_seq(raw_prompts),
+            )
         elif isinstance(request, PoolingChatRequest):
             error_check_ret = self._validate_chat_template(
                 request_chat_template=request.chat_template,
@@ -146,12 +149,11 @@ class OpenAIServingPooling(OpenAIServing):
             pooling_params = self.io_processor.merge_pooling_params()
             if pooling_params.task is None:
                 pooling_params.task = "plugin"
-
-            tokenization_kwargs: dict[str, Any] = {}
         else:
             pooling_params = request.to_pooling_params()  # type: ignore
-            tok_params = request.build_tok_params(self.model_config)  # type: ignore
-            tokenization_kwargs = tok_params.get_encode_kwargs()
+
+        tok_params = request.build_tok_params(self.model_config)
+        tokenization_kwargs = tok_params.get_encode_kwargs()

         for i, engine_prompt in enumerate(engine_prompts):
             request_id_item = f"{request_id}-{i}"
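
For reviewers, a self-contained sketch of the flow this diff converges on: request objects build a single `TokenizeParams` via `build_tok_params()`, and serving code derives tokenizer kwargs from it with `get_encode_kwargs()` instead of assembling per-branch `tokenization_kwargs` dicts. The stand-in class below (including the body of `get_encode_kwargs()`) is hypothetical and only mirrors the field names visible in this diff, not vLLM's actual implementation.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class TokenizeParams:
    # Hypothetical stand-in; fields mirror build_tok_params() above,
    # but the real TokenizeParams class is not shown in this diff.
    max_total_tokens: int
    max_output_tokens: int
    truncate_prompt_tokens: int | None
    do_lower_case: bool
    add_special_tokens: bool
    max_total_tokens_param: str

    def get_encode_kwargs(self) -> dict[str, Any]:
        # Guessed body: surface only the knobs a tokenizer encode() call
        # would consume; the real get_encode_kwargs() may differ.
        kwargs: dict[str, Any] = {"add_special_tokens": self.add_special_tokens}
        if self.truncate_prompt_tokens is not None:
            kwargs["truncation"] = True
            kwargs["max_length"] = self.truncate_prompt_tokens
        return kwargs


# Pooling requests generate no completion tokens, hence max_output_tokens=0:
# the entire max_model_len budget is available to the prompt.
params = TokenizeParams(
    max_total_tokens=8192,
    max_output_tokens=0,
    truncate_prompt_tokens=512,
    do_lower_case=False,
    add_special_tokens=True,
    max_total_tokens_param="max_model_len",
)
print(params.get_encode_kwargs())
# -> {'add_special_tokens': True, 'truncation': True, 'max_length': 512}
```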