[Refactor] Call renderer for online IO processor request (#34490)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -500,7 +500,7 @@ class LLM:
|
||||
engine_prompts: Sequence[DictPrompt | TokPrompt] = [
|
||||
engine_prompt
|
||||
for prompt, param in zip(seq_prompts, seq_params)
|
||||
for engine_prompt in self._preprocess_completion(
|
||||
for engine_prompt in self._preprocess_cmpl(
|
||||
[prompt],
|
||||
tokenization_kwargs=merge_kwargs(
|
||||
tokenization_kwargs,
|
||||
@@ -509,7 +509,7 @@ class LLM:
|
||||
)
|
||||
]
|
||||
else:
|
||||
engine_prompts = self._preprocess_completion(
|
||||
engine_prompts = self._preprocess_cmpl(
|
||||
seq_prompts,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
@@ -889,7 +889,7 @@ class LLM:
|
||||
add_special_tokens=not model_config.is_encoder_decoder,
|
||||
).with_kwargs(tokenization_kwargs)
|
||||
|
||||
def _preprocess_completion(
|
||||
def _preprocess_cmpl(
|
||||
self,
|
||||
prompts: Sequence[PromptType],
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
@@ -901,7 +901,7 @@ class LLM:
|
||||
Refer to [LLM.generate][] for a complete description of the arguments.
|
||||
|
||||
Returns:
|
||||
A list of `TokensPrompts` objects containing the tokenized prompt
|
||||
A list of `TokPrompt` objects containing the tokenized prompt
|
||||
after chat template interpolation, and the raw multi-modal inputs.
|
||||
"""
|
||||
renderer = self.renderer
|
||||
@@ -943,7 +943,7 @@ class LLM:
|
||||
Refer to [LLM.chat][] for a complete description of the arguments.
|
||||
|
||||
Returns:
|
||||
A list of `TokensPrompts` objects containing the tokenized prompt
|
||||
A list of `TokPrompt` objects containing the tokenized prompt
|
||||
after chat template interpolation, and the raw multi-modal inputs.
|
||||
"""
|
||||
renderer = self.renderer
|
||||
@@ -1823,11 +1823,11 @@ class LLM:
|
||||
if any(param.truncate_prompt_tokens is not None for param in seq_params):
|
||||
# TODO: Remove this after deprecating `param.truncate_prompt_tokens`
|
||||
# Then, move the code from the `else` block to the top and let
|
||||
# `self._preprocess_completion` handle prompt normalization
|
||||
# `self._preprocess_cmpl` handle prompt normalization
|
||||
engine_prompts: Sequence[DictPrompt | TokPrompt] = [
|
||||
engine_prompt
|
||||
for prompt, param in zip(seq_prompts, seq_params)
|
||||
for engine_prompt in self._preprocess_completion(
|
||||
for engine_prompt in self._preprocess_cmpl(
|
||||
[prompt],
|
||||
tokenization_kwargs=merge_kwargs(
|
||||
tokenization_kwargs,
|
||||
@@ -1836,7 +1836,7 @@ class LLM:
|
||||
)
|
||||
]
|
||||
else:
|
||||
engine_prompts = self._preprocess_completion(
|
||||
engine_prompts = self._preprocess_cmpl(
|
||||
seq_prompts,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ import json
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import AsyncGenerator, Callable, Mapping
|
||||
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from http import HTTPStatus
|
||||
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
|
||||
@@ -959,15 +959,22 @@ class OpenAIServing:
|
||||
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
|
||||
prompt_embeds: bytes | list[bytes] | None,
|
||||
) -> list[TokPrompt]:
|
||||
renderer = self.renderer
|
||||
model_config = self.model_config
|
||||
|
||||
prompts = list[SingletonPrompt | bytes]()
|
||||
if prompt_embeds is not None: # embeds take higher priority
|
||||
prompts.extend(prompt_to_seq(prompt_embeds))
|
||||
if prompt_input is not None:
|
||||
prompts.extend(prompt_to_seq(prompt_input))
|
||||
|
||||
return await self._preprocess_cmpl(request, prompts)
|
||||
|
||||
async def _preprocess_cmpl(
|
||||
self,
|
||||
request: RendererRequest,
|
||||
prompts: Sequence[PromptType | bytes],
|
||||
) -> list[TokPrompt]:
|
||||
renderer = self.renderer
|
||||
model_config = self.model_config
|
||||
|
||||
parsed_prompts = [
|
||||
(
|
||||
prompt
|
||||
|
||||
@@ -100,6 +100,18 @@ class IOProcessorRequest(PoolingBasicRequestMixin, EncodingRequestMixin, Generic
|
||||
data: T
|
||||
task: PoolingTask = "plugin"
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
encoder_config = model_config.encoder_config or {}
|
||||
|
||||
return TokenizeParams(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=0,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
add_special_tokens=not model_config.is_encoder_decoder,
|
||||
max_total_tokens_param="max_model_len",
|
||||
)
|
||||
|
||||
|
||||
class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
|
||||
request_id: str | None = None
|
||||
|
||||
@@ -6,7 +6,7 @@ import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Callable, Sequence
|
||||
from functools import partial
|
||||
from typing import Any, Final, Literal, cast
|
||||
from typing import Final, Literal, cast
|
||||
|
||||
import jinja2
|
||||
from fastapi import Request
|
||||
@@ -108,7 +108,10 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
raw_prompts = await self.io_processor.pre_process_async(
|
||||
prompt=validated_prompt, request_id=request_id
|
||||
)
|
||||
engine_prompts = prompt_to_seq(raw_prompts)
|
||||
engine_prompts = await self._preprocess_cmpl(
|
||||
request,
|
||||
prompt_to_seq(raw_prompts),
|
||||
)
|
||||
elif isinstance(request, PoolingChatRequest):
|
||||
error_check_ret = self._validate_chat_template(
|
||||
request_chat_template=request.chat_template,
|
||||
@@ -146,12 +149,11 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
pooling_params = self.io_processor.merge_pooling_params()
|
||||
if pooling_params.task is None:
|
||||
pooling_params.task = "plugin"
|
||||
|
||||
tokenization_kwargs: dict[str, Any] = {}
|
||||
else:
|
||||
pooling_params = request.to_pooling_params() # type: ignore
|
||||
tok_params = request.build_tok_params(self.model_config) # type: ignore
|
||||
tokenization_kwargs = tok_params.get_encode_kwargs()
|
||||
|
||||
tok_params = request.build_tok_params(self.model_config)
|
||||
tokenization_kwargs = tok_params.get_encode_kwargs()
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
Reference in New Issue
Block a user