[Core] Factor out input preprocessing to a separate class (#7329)
This commit is contained in:
@@ -4,22 +4,17 @@ from functools import partial
|
||||
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
|
||||
Mapping, Optional, Set, Tuple, Type, Union)
|
||||
|
||||
from typing_extensions import assert_never
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig)
|
||||
from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_timeout import asyncio_timeout
|
||||
from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
|
||||
PromptComponents, SchedulerOutputState)
|
||||
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
|
||||
from vllm.engine.metrics_types import StatLoggerBase
|
||||
from vllm.executor.executor_base import ExecutorAsyncBase
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||
from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
|
||||
SingletonPromptInputs)
|
||||
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
|
||||
from vllm.inputs import PromptInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
@@ -403,139 +398,6 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
"""Stop the remote worker execution loop."""
|
||||
await self.model_executor.stop_remote_worker_execution_loop_async()
|
||||
|
||||
async def _tokenize_prompt_async(
|
||||
self,
|
||||
prompt: str,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest],
|
||||
) -> List[int]:
|
||||
"""Async version of :meth:`_tokenize_prompt`."""
|
||||
tokenizer = self.get_tokenizer_group(
|
||||
missing_msg="prompts must be None if skip_tokenizer_init is True")
|
||||
|
||||
return await tokenizer.encode_async(request_id=request_id,
|
||||
prompt=prompt,
|
||||
lora_request=lora_request)
|
||||
|
||||
async def _extract_prompt_components_async(
|
||||
self,
|
||||
inputs: SingletonPromptInputs,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
) -> PromptComponents:
|
||||
"""Async version of :meth:`_extract_prompt_components`."""
|
||||
if isinstance(inputs, str):
|
||||
prompt = inputs
|
||||
prompt_token_ids = await self._tokenize_prompt_async(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
multi_modal_data = None
|
||||
elif isinstance(inputs, dict):
|
||||
if "prompt_token_ids" in inputs:
|
||||
prompt = None
|
||||
prompt_token_ids = inputs["prompt_token_ids"]
|
||||
else:
|
||||
# NOTE: This extra assignment is required to pass mypy
|
||||
prompt = parsed_prompt = inputs["prompt"]
|
||||
prompt_token_ids = await self._tokenize_prompt_async(
|
||||
parsed_prompt,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
else:
|
||||
assert_never(inputs)
|
||||
|
||||
return prompt, prompt_token_ids, multi_modal_data
|
||||
|
||||
async def _process_encoder_decoder_prompt_async(
|
||||
self,
|
||||
inputs: PromptInputs,
|
||||
request_id: str,
|
||||
) -> EncoderDecoderLLMInputs:
|
||||
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
|
||||
encoder_comps: PromptComponents
|
||||
decoder_comps: DecoderPromptComponents
|
||||
|
||||
if is_explicit_encoder_decoder_prompt(inputs):
|
||||
encoder_task = self._extract_prompt_components_async(
|
||||
inputs["encoder_prompt"],
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
if (decoder_input := inputs["decoder_prompt"]) is None:
|
||||
encoder_comps = await encoder_task
|
||||
decoder_comps = None, None, None
|
||||
else:
|
||||
decoder_task = self._extract_prompt_components_async(
|
||||
decoder_input,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
encoder_comps, decoder_comps = await asyncio.gather(
|
||||
encoder_task, decoder_task)
|
||||
else:
|
||||
encoder_comps = await self._extract_prompt_components_async(
|
||||
inputs,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
decoder_comps = None, None, None
|
||||
|
||||
return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
|
||||
|
||||
async def _process_decoder_only_prompt_async(
|
||||
self,
|
||||
inputs: SingletonPromptInputs,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> LLMInputs:
|
||||
"""Async version of :meth:`_process_decoder_only_prompt`."""
|
||||
prompt_comps = await self._extract_prompt_components_async(
|
||||
inputs,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
return self._build_decoder_only_llm_inputs(
|
||||
prompt_comps,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
|
||||
async def process_model_inputs_async(
|
||||
self,
|
||||
inputs: PromptInputs,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
|
||||
"""Async version of :meth:`process_model_inputs`."""
|
||||
if self.is_encoder_decoder_model():
|
||||
# Encoder-decoder model requires special mapping of
|
||||
# input prompts to encoder & decoder
|
||||
model_inputs = await self._process_encoder_decoder_prompt_async(
|
||||
inputs,
|
||||
request_id=request_id,
|
||||
)
|
||||
else:
|
||||
if is_explicit_encoder_decoder_prompt(inputs):
|
||||
raise ValueError("Cannot pass encoder-decoder prompt "
|
||||
"to decoder-only models")
|
||||
|
||||
# Decoder-only operation
|
||||
model_inputs = await self._process_decoder_only_prompt_async(
|
||||
inputs,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
|
||||
return self.input_processor(model_inputs)
|
||||
|
||||
async def add_request_async(
|
||||
self,
|
||||
request_id: str,
|
||||
@@ -553,12 +415,13 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
|
||||
processed_inputs = await self.process_model_inputs_async(
|
||||
preprocessed_inputs = await self.input_preprocessor.preprocess_async(
|
||||
inputs,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
processed_inputs = self.input_processor(preprocessed_inputs)
|
||||
|
||||
self._add_processed_request(
|
||||
request_id=request_id,
|
||||
|
||||
Reference in New Issue
Block a user