[Refactor] Call renderer for online IO processor request (#34490)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Cyrus Leung
2026-02-13 14:48:45 +08:00
committed by GitHub
parent eea3024f43
commit ec090c2429
4 changed files with 39 additions and 18 deletions

View File

@@ -500,7 +500,7 @@ class LLM:
engine_prompts: Sequence[DictPrompt | TokPrompt] = [
engine_prompt
for prompt, param in zip(seq_prompts, seq_params)
for engine_prompt in self._preprocess_completion(
for engine_prompt in self._preprocess_cmpl(
[prompt],
tokenization_kwargs=merge_kwargs(
tokenization_kwargs,
@@ -509,7 +509,7 @@ class LLM:
)
]
else:
engine_prompts = self._preprocess_completion(
engine_prompts = self._preprocess_cmpl(
seq_prompts,
tokenization_kwargs=tokenization_kwargs,
)
@@ -889,7 +889,7 @@ class LLM:
add_special_tokens=not model_config.is_encoder_decoder,
).with_kwargs(tokenization_kwargs)
def _preprocess_completion(
def _preprocess_cmpl(
self,
prompts: Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None,
@@ -901,7 +901,7 @@ class LLM:
Refer to [LLM.generate][] for a complete description of the arguments.
Returns:
A list of `TokensPrompts` objects containing the tokenized prompt
A list of `TokPrompt` objects containing the tokenized prompt
after chat template interpolation, and the raw multi-modal inputs.
"""
renderer = self.renderer
@@ -943,7 +943,7 @@ class LLM:
Refer to [LLM.chat][] for a complete description of the arguments.
Returns:
A list of `TokensPrompts` objects containing the tokenized prompt
A list of `TokPrompt` objects containing the tokenized prompt
after chat template interpolation, and the raw multi-modal inputs.
"""
renderer = self.renderer
@@ -1823,11 +1823,11 @@ class LLM:
if any(param.truncate_prompt_tokens is not None for param in seq_params):
# TODO: Remove this after deprecating `param.truncate_prompt_tokens`
# Then, move the code from the `else` block to the top and let
# `self._preprocess_completion` handle prompt normalization
# `self._preprocess_cmpl` handle prompt normalization
engine_prompts: Sequence[DictPrompt | TokPrompt] = [
engine_prompt
for prompt, param in zip(seq_prompts, seq_params)
for engine_prompt in self._preprocess_completion(
for engine_prompt in self._preprocess_cmpl(
[prompt],
tokenization_kwargs=merge_kwargs(
tokenization_kwargs,
@@ -1836,7 +1836,7 @@ class LLM:
)
]
else:
engine_prompts = self._preprocess_completion(
engine_prompts = self._preprocess_cmpl(
seq_prompts,
tokenization_kwargs=tokenization_kwargs,
)

View File

@@ -5,7 +5,7 @@ import json
import sys
import time
import traceback
from collections.abc import AsyncGenerator, Callable, Mapping
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
@@ -959,15 +959,22 @@ class OpenAIServing:
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None,
) -> list[TokPrompt]:
renderer = self.renderer
model_config = self.model_config
prompts = list[SingletonPrompt | bytes]()
if prompt_embeds is not None: # embeds take higher priority
prompts.extend(prompt_to_seq(prompt_embeds))
if prompt_input is not None:
prompts.extend(prompt_to_seq(prompt_input))
return await self._preprocess_cmpl(request, prompts)
async def _preprocess_cmpl(
self,
request: RendererRequest,
prompts: Sequence[PromptType | bytes],
) -> list[TokPrompt]:
renderer = self.renderer
model_config = self.model_config
parsed_prompts = [
(
prompt

View File

@@ -100,6 +100,18 @@ class IOProcessorRequest(PoolingBasicRequestMixin, EncodingRequestMixin, Generic
data: T
task: PoolingTask = "plugin"
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
    """Build the tokenization parameters for this IO processor request.

    Args:
        model_config: Model configuration providing `max_model_len`,
            the optional `encoder_config` mapping, and the
            `is_encoder_decoder` flag.

    Returns:
        A `TokenizeParams` whose total-token limit is the model's
        `max_model_len`, with an output-token budget of zero and the
        request's own `truncate_prompt_tokens` setting.
    """
    # `encoder_config` may be unset or empty; an empty mapping makes the
    # lookup below fall back to its default.
    encoder_settings = model_config.encoder_config
    if not encoder_settings:
        encoder_settings = {}

    return TokenizeParams(
        max_total_tokens=model_config.max_model_len,
        max_output_tokens=0,
        truncate_prompt_tokens=self.truncate_prompt_tokens,
        do_lower_case=encoder_settings.get("do_lower_case", False),
        # Special tokens are added only when the model is not encoder-decoder.
        add_special_tokens=not model_config.is_encoder_decoder,
        max_total_tokens_param="max_model_len",
    )
class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
request_id: str | None = None

View File

@@ -6,7 +6,7 @@ import json
import time
from collections.abc import AsyncGenerator, Callable, Sequence
from functools import partial
from typing import Any, Final, Literal, cast
from typing import Final, Literal, cast
import jinja2
from fastapi import Request
@@ -108,7 +108,10 @@ class OpenAIServingPooling(OpenAIServing):
raw_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id
)
engine_prompts = prompt_to_seq(raw_prompts)
engine_prompts = await self._preprocess_cmpl(
request,
prompt_to_seq(raw_prompts),
)
elif isinstance(request, PoolingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=request.chat_template,
@@ -146,12 +149,11 @@ class OpenAIServingPooling(OpenAIServing):
pooling_params = self.io_processor.merge_pooling_params()
if pooling_params.task is None:
pooling_params.task = "plugin"
tokenization_kwargs: dict[str, Any] = {}
else:
pooling_params = request.to_pooling_params() # type: ignore
tok_params = request.build_tok_params(self.model_config) # type: ignore
tokenization_kwargs = tok_params.get_encode_kwargs()
tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"