[Misc] Reorganize inputs (#35182)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
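
For context, a minimal sketch of the call-site migration this diff performs. It uses only names that appear in the hunks below (`EngineInput`, `tokens_input`, the `cache_salt` keyword); the wrapper function itself is hypothetical and not part of this change:

```python
# Illustrative only: the import/constructor rename applied in this diff.
# Old: from vllm.inputs.data import ProcessorInputs, token_inputs
# New: from vllm.inputs import EngineInput, tokens_input
from vllm.inputs import EngineInput, tokens_input


def build_engine_input(
    prompt_token_ids: list[int], cache_salt: str | None = None
) -> EngineInput:
    # Hypothetical helper: tokens_input now accepts cache_salt directly,
    # instead of assigning engine_prompt["cache_salt"] after construction
    # (see the harmony hunk below).
    return tokens_input(prompt_token_ids, cache_salt=cache_salt)
```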
@@ -63,7 +63,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
 )
 from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
 from vllm.entrypoints.utils import get_max_tokens, should_include_usage
-from vllm.inputs.data import ProcessorInputs
+from vllm.inputs import EngineInput
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob
 from vllm.outputs import CompletionOutput, RequestOutput
@@ -177,7 +177,7 @@ class OpenAIServingChat(OpenAIServing):
     async def render_chat_request(
         self,
         request: ChatCompletionRequest,
-    ) -> tuple[list[ConversationMessage], list[ProcessorInputs]] | ErrorResponse:
+    ) -> tuple[list[ConversationMessage], list[EngineInput]] | ErrorResponse:
         """
         Validate the model and preprocess a chat completion request.

@@ -185,7 +185,7 @@ class OpenAIServingChat(OpenAIServing):
         engine-aware checks (LoRA model validation, engine health).

         Returns:
-            A tuple of (conversation, engine_prompts) on success,
+            A tuple of (conversation, engine_inputs) on success,
             or an ErrorResponse on failure.
         """
         error_check_ret = await self._check_model(request)
@@ -231,7 +231,7 @@ class OpenAIServingChat(OpenAIServing):
         if isinstance(result, ErrorResponse):
             return result

-        conversation, engine_prompts = result
+        conversation, engine_inputs = result

         request_id = (
             f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}"
@@ -251,13 +251,13 @@ class OpenAIServingChat(OpenAIServing):
         # Schedule the request and get the result generator.
         max_model_len = self.model_config.max_model_len
         generators: list[AsyncGenerator[RequestOutput, None]] = []
-        for i, engine_prompt in enumerate(engine_prompts):
-            prompt_token_ids = self._extract_prompt_components(engine_prompt).token_ids
+        for i, engine_input in enumerate(engine_inputs):
+            prompt_token_ids = self._extract_prompt_components(engine_input).token_ids

             # If we are creating sub requests for multiple prompts, ensure that they
             # have unique request ids.
             sub_request_id = (
-                request_id if len(engine_prompts) == 1 else f"{request_id}_{i}"
+                request_id if len(engine_inputs) == 1 else f"{request_id}_{i}"
             )

             max_tokens = get_max_tokens(
@@ -265,7 +265,7 @@ class OpenAIServingChat(OpenAIServing):
                 request.max_completion_tokens
                 if request.max_completion_tokens is not None
                 else request.max_tokens,
-                self._extract_prompt_len(engine_prompt),
+                self._extract_prompt_len(engine_input),
                 self.default_sampling_params,
                 self.override_max_tokens,
             )
@@ -283,7 +283,7 @@ class OpenAIServingChat(OpenAIServing):

             self._log_inputs(
                 sub_request_id,
-                engine_prompt,
+                engine_input,
                 params=sampling_params,
                 lora_request=lora_request,
             )
@@ -296,7 +296,7 @@ class OpenAIServingChat(OpenAIServing):

             if isinstance(sampling_params, BeamSearchParams):
                 generator = self.beam_search(
-                    prompt=engine_prompt,
+                    prompt=engine_input,
                     request_id=sub_request_id,
                     params=sampling_params,
                     lora_request=lora_request,
@@ -313,7 +313,7 @@ class OpenAIServingChat(OpenAIServing):
                 reasoning_ended = None

                 generator = self.engine_client.generate(
-                    engine_prompt,
+                    engine_input,
                     sampling_params,
                     sub_request_id,
                     lora_request=lora_request,
@@ -33,7 +33,7 @@ from vllm.entrypoints.openai.engine.serving import (
 from vllm.entrypoints.openai.models.serving import OpenAIServingModels
 from vllm.entrypoints.utils import get_max_tokens, should_include_usage
 from vllm.exceptions import VLLMValidationError
-from vllm.inputs.data import ProcessorInputs
+from vllm.inputs import EngineInput
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob
 from vllm.outputs import RequestOutput
@@ -82,7 +82,7 @@ class OpenAIServingCompletion(OpenAIServing):
     async def render_completion_request(
         self,
         request: CompletionRequest,
-    ) -> list[ProcessorInputs] | ErrorResponse:
+    ) -> list[EngineInput] | ErrorResponse:
         """
         Validate the model and preprocess a completion request.

@@ -90,8 +90,7 @@ class OpenAIServingCompletion(OpenAIServing):
         engine-aware checks (LoRA model validation, engine health).

         Returns:
-            A list of engine_prompts on success,
-            or an ErrorResponse on failure.
+            A list of engine_inputs on success, or an ErrorResponse on failure.
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
@@ -128,7 +127,7 @@ class OpenAIServingCompletion(OpenAIServing):
         if isinstance(result, ErrorResponse):
             return result

-        engine_prompts = result
+        engine_inputs = result

         request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}"
         created_time = int(time.time())
@@ -145,11 +144,11 @@ class OpenAIServingCompletion(OpenAIServing):
         # Schedule the request and get the result generator.
         max_model_len = self.model_config.max_model_len
         generators: list[AsyncGenerator[RequestOutput, None]] = []
-        for i, engine_prompt in enumerate(engine_prompts):
+        for i, engine_input in enumerate(engine_inputs):
             max_tokens = get_max_tokens(
                 max_model_len,
                 request.max_tokens,
-                self._extract_prompt_len(engine_prompt),
+                self._extract_prompt_len(engine_input),
                 self.default_sampling_params,
                 self.override_max_tokens,
             )
@@ -169,7 +168,7 @@ class OpenAIServingCompletion(OpenAIServing):

             self._log_inputs(
                 request_id_item,
-                engine_prompt,
+                engine_input,
                 params=sampling_params,
                 lora_request=lora_request,
             )
@@ -182,7 +181,7 @@ class OpenAIServingCompletion(OpenAIServing):

             if isinstance(sampling_params, BeamSearchParams):
                 generator = self.beam_search(
-                    prompt=engine_prompt,
+                    prompt=engine_input,
                     request_id=request_id,
                     params=sampling_params,
                     lora_request=lora_request,
@@ -190,7 +189,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 )
             else:
                 generator = self.engine_client.generate(
-                    engine_prompt,
+                    engine_input,
                     sampling_params,
                     request_id_item,
                     lora_request=lora_request,
@@ -204,7 +203,7 @@ class OpenAIServingCompletion(OpenAIServing):
         result_generator = merge_async_iterators(*generators)

         model_name = self.models.model_name(lora_request)
-        num_prompts = len(engine_prompts)
+        num_prompts = len(engine_inputs)

         # Streaming response
         tokenizer = self.renderer.tokenizer
@@ -212,7 +211,7 @@ class OpenAIServingCompletion(OpenAIServing):
         if request.stream:
             return self.completion_stream_generator(
                 request,
-                engine_prompts,
+                engine_inputs,
                 result_generator,
                 request_id,
                 created_time,
@@ -235,8 +234,7 @@ class OpenAIServingCompletion(OpenAIServing):
             # We did not pass it into vLLM engine to avoid being redundant
             # with the inputs token IDs
             if final_res.prompt is None:
-                engine_prompt = engine_prompts[i]
-                final_res.prompt = self._extract_prompt_text(engine_prompt)
+                final_res.prompt = self._extract_prompt_text(engine_inputs[i])

         final_res_batch_checked = cast(list[RequestOutput], final_res_batch)

@@ -268,7 +266,7 @@ class OpenAIServingCompletion(OpenAIServing):
     async def completion_stream_generator(
         self,
         request: CompletionRequest,
-        engine_prompts: list[ProcessorInputs],
+        engine_inputs: list[EngineInput],
         result_generator: AsyncIterator[tuple[int, RequestOutput]],
         request_id: str,
         created_time: int,
@@ -301,8 +299,8 @@ class OpenAIServingCompletion(OpenAIServing):

                 prompt_text = res.prompt
                 if prompt_text is None:
-                    engine_prompt = engine_prompts[prompt_idx]
-                    prompt_text = self._extract_prompt_text(engine_prompt)
+                    engine_input = engine_inputs[prompt_idx]
+                    prompt_text = self._extract_prompt_text(engine_input)

                 # Prompt details are excluded from later streamed outputs
                 if prompt_token_ids is not None:
@@ -72,11 +72,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
 )
 from vllm.entrypoints.utils import create_error_response
 from vllm.exceptions import VLLMValidationError
-from vllm.inputs.data import (
-    ProcessorInputs,
-    PromptType,
-    TokensPrompt,
-)
+from vllm.inputs import EngineInput, PromptType, TokensPrompt
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob, PromptLogprobs
 from vllm.lora.request import LoRARequest
@@ -163,7 +159,7 @@ class ServeContext(Generic[RequestT]):
     request_id: str
     created_time: int = field(default_factory=lambda: int(time.time()))
     lora_request: LoRARequest | None = None
-    engine_prompts: list[ProcessorInputs] | None = None
+    engine_inputs: list[EngineInput] | None = None

     result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
         None
@@ -202,7 +198,7 @@ class OpenAIServing:

     async def beam_search(
         self,
-        prompt: ProcessorInputs,
+        prompt: EngineInput,
         request_id: str,
         params: BeamSearchParams,
         lora_request: LoRARequest | None = None,
@@ -493,21 +489,21 @@ class OpenAIServing:
         if isinstance(pooling_params, ErrorResponse):
             return pooling_params

-        if ctx.engine_prompts is None:
+        if ctx.engine_inputs is None:
             return self.create_error_response("Engine prompts not available")

-        for i, engine_prompt in enumerate(ctx.engine_prompts):
+        for i, engine_input in enumerate(ctx.engine_inputs):
             request_id_item = f"{ctx.request_id}-{i}"

             self._log_inputs(
                 request_id_item,
-                engine_prompt,
+                engine_input,
                 params=pooling_params,
                 lora_request=ctx.lora_request,
             )

             generator = self.engine_client.encode(
-                engine_prompt,
+                engine_input,
                 pooling_params,
                 request_id_item,
                 lora_request=ctx.lora_request,
@@ -526,10 +522,10 @@ class OpenAIServing:
         ctx: ServeContext,
     ) -> ErrorResponse | None:
         """Collect batch results from the result generator."""
-        if ctx.engine_prompts is None:
+        if ctx.engine_inputs is None:
             return self.create_error_response("Engine prompts not available")

-        num_prompts = len(ctx.engine_prompts)
+        num_prompts = len(ctx.engine_inputs)
         final_res_batch: list[PoolingRequestOutput | None]
         final_res_batch = [None] * num_prompts

@@ -806,19 +802,19 @@ class OpenAIServing:
         # Apply server defaults first, then request kwargs override.
         return default_chat_template_kwargs | request_chat_template_kwargs

-    def _extract_prompt_components(self, prompt: PromptType | ProcessorInputs):
+    def _extract_prompt_components(self, prompt: PromptType | EngineInput):
         return extract_prompt_components(self.model_config, prompt)

-    def _extract_prompt_text(self, prompt: ProcessorInputs):
+    def _extract_prompt_text(self, prompt: PromptType | EngineInput):
         return self._extract_prompt_components(prompt).text

-    def _extract_prompt_len(self, prompt: ProcessorInputs):
+    def _extract_prompt_len(self, prompt: EngineInput):
         return extract_prompt_len(self.model_config, prompt)

     def _log_inputs(
         self,
         request_id: str,
-        inputs: PromptType | ProcessorInputs,
+        inputs: PromptType | EngineInput,
         params: SamplingParams | PoolingParams | BeamSearchParams | None,
         lora_request: LoRARequest | None,
     ) -> None:
@@ -12,7 +12,7 @@ from vllm.engine.protocol import EngineClient, StreamingInput
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.engine.serving import OpenAIServing
 from vllm.entrypoints.openai.models.serving import OpenAIServingModels
-from vllm.inputs.data import PromptType
+from vllm.inputs import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.models.interfaces import SupportsRealtime
 from vllm.renderers.inputs.preprocess import parse_model_prompt
@@ -83,6 +83,6 @@ class OpenAIServingRealtime(OpenAIServing):

         async for prompt in stream_input_iter:
             parsed_prompt = parse_model_prompt(model_config, prompt)
-            (engine_prompt,) = await renderer.render_cmpl_async([parsed_prompt])
+            (engine_input,) = await renderer.render_cmpl_async([parsed_prompt])

-            yield StreamingInput(prompt=engine_prompt)
+            yield StreamingInput(prompt=engine_input)
@@ -110,7 +110,7 @@ from vllm.entrypoints.openai.responses.utils import (
 from vllm.entrypoints.serve.render.serving import OpenAIServingRender
 from vllm.entrypoints.utils import get_max_tokens
 from vllm.exceptions import VLLMValidationError
-from vllm.inputs.data import ProcessorInputs, token_inputs
+from vllm.inputs import EngineInput, tokens_input
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob as SampleLogprob
 from vllm.logprobs import SampleLogprobs
@@ -269,10 +269,10 @@ class OpenAIServingResponses(OpenAIServing):

     def _validate_generator_input(
         self,
-        engine_prompt: ProcessorInputs,
+        engine_input: EngineInput,
     ) -> ErrorResponse | None:
         """Add validations to the input to the generator here."""
-        prompt_len = self._extract_prompt_len(engine_prompt)
+        prompt_len = self._extract_prompt_len(engine_input)
         max_model_len = self.model_config.max_model_len

         if prompt_len >= max_model_len:
@@ -369,11 +369,11 @@ class OpenAIServingResponses(OpenAIServing):
         model_name = self.models.model_name(lora_request)

         if self.use_harmony:
-            messages, engine_prompts = self._make_request_with_harmony(
+            messages, engine_inputs = self._make_request_with_harmony(
                 request, prev_response
             )
         else:
-            messages, engine_prompts = await self._make_request(request, prev_response)
+            messages, engine_inputs = await self._make_request(request, prev_response)

         request_metadata = RequestResponseMetadata(request_id=request.request_id)
         if raw_request:
@@ -413,15 +413,15 @@ class OpenAIServingResponses(OpenAIServing):
         available_tools = []
         tokenizer = self.renderer.get_tokenizer()

-        for engine_prompt in engine_prompts:
-            maybe_error = self._validate_generator_input(engine_prompt)
+        for engine_input in engine_inputs:
+            maybe_error = self._validate_generator_input(engine_input)
             if maybe_error is not None:
                 return maybe_error

             default_max_tokens = get_max_tokens(
                 max_model_len,
                 request.max_output_tokens,
-                self._extract_prompt_len(engine_prompt),
+                self._extract_prompt_len(engine_input),
                 self.default_sampling_params,
                 self.override_max_tokens,
             )
@@ -480,7 +480,7 @@ class OpenAIServingResponses(OpenAIServing):
             )
             generator = self._generate_with_builtin_tools(
                 request_id=request.request_id,
-                engine_prompt=engine_prompt,
+                engine_input=engine_input,
                 sampling_params=sampling_params,
                 context=context,
                 lora_request=lora_request,
@@ -586,7 +586,7 @@ class OpenAIServingResponses(OpenAIServing):
             prev_response_output=prev_response.output if prev_response else None,
         )

-        _, engine_prompts = await self.openai_serving_render.preprocess_chat(
+        _, engine_inputs = await self.openai_serving_render.preprocess_chat(
             request,
             messages,
             default_template=self.chat_template,
@@ -595,7 +595,7 @@ class OpenAIServingResponses(OpenAIServing):
             tool_dicts=tool_dicts,
             tool_parser=self.parser.tool_parser_cls if self.parser else None,
         )
-        return messages, engine_prompts
+        return messages, engine_inputs

     async def _render_next_turn(
         self,
@@ -610,7 +610,7 @@ class OpenAIServingResponses(OpenAIServing):
             request_input=messages,
         )

-        _, engine_prompts = await self.openai_serving_render.preprocess_chat(
+        _, engine_inputs = await self.openai_serving_render.preprocess_chat(
             request,
             new_messages,
             default_template=chat_template,
@@ -619,12 +619,12 @@ class OpenAIServingResponses(OpenAIServing):
             tool_dicts=tool_dicts,
             tool_parser=tool_parser,
         )
-        return engine_prompts
+        return engine_inputs

     async def _generate_with_builtin_tools(
         self,
         request_id: str,
-        engine_prompt: ProcessorInputs,
+        engine_input: EngineInput,
         sampling_params: SamplingParams,
         context: ConversationContext,
         lora_request: LoRARequest | None = None,
@@ -641,13 +641,13 @@ class OpenAIServingResponses(OpenAIServing):

         self._log_inputs(
             sub_request_id,
-            engine_prompt,
+            engine_input,
             params=sampling_params,
             lora_request=lora_request,
         )

         generator = self.engine_client.generate(
-            engine_prompt,
+            engine_input,
             sampling_params,
             sub_request_id,
             lora_request=lora_request,
@@ -675,11 +675,11 @@ class OpenAIServingResponses(OpenAIServing):
             # Render the next prompt token ids and update sampling_params.
             if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
                 token_ids = context.render_for_completion()
-                engine_prompt = token_inputs(token_ids)
+                engine_input = tokens_input(token_ids)

                 sampling_params.max_tokens = max_model_len - len(token_ids)
             elif isinstance(context, ParsableContext):
-                (engine_prompt,) = await self._render_next_turn(
+                (engine_input,) = await self._render_next_turn(
                     context.request,
                     context.parser.response_messages,
                     context.tool_dicts,
@@ -691,7 +691,7 @@ class OpenAIServingResponses(OpenAIServing):
                 sampling_params.max_tokens = get_max_tokens(
                     max_model_len,
                     context.request.max_output_tokens,
-                    self._extract_prompt_len(engine_prompt),
+                    self._extract_prompt_len(engine_input),
                     self.default_sampling_params,  # type: ignore
                     self.override_max_tokens,  # type: ignore
                 )
@@ -713,14 +713,10 @@ class OpenAIServingResponses(OpenAIServing):
         arrival_time = time.time()
         messages = self._construct_input_messages_with_harmony(request, prev_response)
         prompt_token_ids = render_for_completion(messages)
-        engine_prompt = token_inputs(prompt_token_ids)
-        engine_prompt["arrival_time"] = arrival_time
-
-        # Add cache_salt if provided in the request
-        if request.cache_salt is not None:
-            engine_prompt["cache_salt"] = request.cache_salt
-
-        return messages, [engine_prompt]
+        engine_input = tokens_input(prompt_token_ids, cache_salt=request.cache_salt)
+        engine_input["arrival_time"] = arrival_time
+
+        return messages, [engine_input]

     async def _initialize_tool_sessions(
         self,
@@ -38,7 +38,7 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
 )
 from vllm.entrypoints.utils import get_max_tokens
 from vllm.exceptions import VLLMValidationError
-from vllm.inputs import EncoderDecoderInputs, ProcessorInputs
+from vllm.inputs import EncoderDecoderInput, EngineInput
 from vllm.logger import init_logger
 from vllm.logprobs import FlatLogprobs, Logprob
 from vllm.model_executor.models import SupportsTranscription
@@ -171,7 +171,7 @@ class OpenAISpeechToText(OpenAIServing):
         request: SpeechToTextRequest,
         audio_data: bytes,
         request_id: str,
-    ) -> tuple[list[ProcessorInputs], float]:
+    ) -> tuple[list[EngineInput], float]:
         # Validate request
         language = self.model_cls.validate_language(request.language)
         # Skip to_language validation to avoid extra logging for Whisper.
@@ -250,9 +250,9 @@ class OpenAISpeechToText(OpenAIServing):

             parsed_prompts.append(parsed_prompt)

-        engine_prompts = await self.renderer.render_cmpl_async(parsed_prompts)
+        engine_inputs = await self.renderer.render_cmpl_async(parsed_prompts)

-        return engine_prompts, duration
+        return engine_inputs, duration

     def _preprocess_verbose_prompt(self, prompt: EncoderDecoderDictPrompt):
         dec_prompt = prompt["decoder_prompt"]
@@ -271,7 +271,7 @@ class OpenAISpeechToText(OpenAIServing):
         return prompt

     @staticmethod
-    def _get_decoder_prompt_len(engine_prompts: list[ProcessorInputs]) -> int:
+    def _get_decoder_prompt_len(engine_inputs: list[EngineInput]) -> int:
         """Get the length of the decoder prompt. Currently we need to offset
         by the decoder prompt length when running beam search because the mm
         encoder is not currently cached and runs on decode calls; because of
@@ -282,12 +282,13 @@ class OpenAISpeechToText(OpenAIServing):
         encoder/decoder caching is implemented.
         """
         input_len = 0
-        assert len(engine_prompts) > 0
-        first_eng_prompt = engine_prompts[0]
+        assert len(engine_inputs) > 0
+        first_input = engine_inputs[0]

-        if first_eng_prompt.get("type") == "enc_dec":
-            first_eng_prompt = cast(EncoderDecoderInputs, first_eng_prompt)
-            input_len = len(first_eng_prompt["decoder_prompt"]["prompt_token_ids"])
+        if first_input.get("type") == "enc_dec":
+            first_input = cast(EncoderDecoderInput, first_input)
+            input_len = len(first_input["decoder_prompt"]["prompt_token_ids"])
+
         return input_len

     def _get_verbose_segments(
@@ -409,7 +410,7 @@ class OpenAISpeechToText(OpenAIServing):

         lora_request = self._maybe_get_adapters(request)

-        engine_prompts, duration_s = await self._preprocess_speech_to_text(
+        engine_inputs, duration_s = await self._preprocess_speech_to_text(
             request=request,
             audio_data=audio_data,
             request_id=request_id,
@@ -420,7 +421,7 @@ class OpenAISpeechToText(OpenAIServing):
         list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None

         input_len = (
-            OpenAISpeechToText._get_decoder_prompt_len(engine_prompts)
+            OpenAISpeechToText._get_decoder_prompt_len(engine_inputs)
             if request.use_beam_search
             else 0
         )
@@ -450,12 +451,12 @@ class OpenAISpeechToText(OpenAIServing):
             sampling_params.logprobs = 1

         list_result_generator = []
-        for i, engine_prompt in enumerate(engine_prompts):
+        for i, engine_input in enumerate(engine_inputs):
             request_id_item = f"{request_id}_{i}"

             self._log_inputs(
                 request_id_item,
-                engine_prompt,
+                engine_input,
                 params=sampling_params,
                 lora_request=lora_request,
             )
@@ -468,7 +469,7 @@ class OpenAISpeechToText(OpenAIServing):

             if isinstance(sampling_params, BeamSearchParams):
                 generator = self.beam_search(
-                    prompt=engine_prompt,
+                    prompt=engine_input,
                     params=sampling_params,
                     request_id=request_id_item,
                     lora_request=lora_request,
@@ -476,7 +477,7 @@ class OpenAISpeechToText(OpenAIServing):
                 )
             else:
                 generator = self.engine_client.generate(
-                    engine_prompt,
+                    engine_input,
                     sampling_params,
                     request_id_item,
                     lora_request=lora_request,