[Core] Support serving encoder/decoder models (#7258)
@@ -5,6 +5,7 @@ from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
                     Optional, Set, Tuple, Type, Union)
 
 from transformers import PreTrainedTokenizer
+from typing_extensions import assert_never
 
 import vllm.envs as envs
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
@@ -12,11 +13,14 @@ from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_timeout import asyncio_timeout
-from vllm.engine.llm_engine import LLMEngine
+from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
+                                    PromptComponents)
 from vllm.engine.metrics import StatLoggerBase
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.ray_utils import initialize_ray_cluster, ray
-from vllm.inputs import LLMInputs, PromptInputs
+from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
+                         SingletonPromptInputs)
+from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput
@@ -293,38 +297,138 @@ class _AsyncLLMEngine(LLMEngine):
         """Stop the remote worker execution loop."""
         await self.model_executor.stop_remote_worker_execution_loop_async()
 
-    async def process_model_inputs_async(
+    async def _tokenize_prompt_async(
         self,
+        prompt: str,
         request_id: str,
+        lora_request: Optional[LoRARequest],
+    ) -> List[int]:
+        """Async version of :meth:`_tokenize_prompt`."""
+        tokenizer = self.get_tokenizer_group("prompts must be None if "
+                                             "skip_tokenizer_init is True")
+
+        return await tokenizer.encode_async(request_id=request_id,
+                                            prompt=prompt,
+                                            lora_request=lora_request)
+
+    async def _extract_prompt_components_async(
+        self,
+        inputs: SingletonPromptInputs,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> PromptComponents:
+        """Async version of :meth:`_extract_prompt_components`."""
+        if isinstance(inputs, str):
+            prompt = inputs
+            prompt_token_ids = await self._tokenize_prompt_async(
+                prompt,
+                request_id=request_id,
+                lora_request=lora_request,
+            )
+            multi_modal_data = None
+        elif isinstance(inputs, dict):
+            if "prompt_token_ids" in inputs:
+                prompt = None
+                prompt_token_ids = inputs["prompt_token_ids"]
+            else:
+                # NOTE: This extra assignment is required to pass mypy
+                prompt = parsed_prompt = inputs["prompt"]
+                prompt_token_ids = await self._tokenize_prompt_async(
+                    parsed_prompt,
+                    request_id=request_id,
+                    lora_request=lora_request,
+                )
+
+            multi_modal_data = inputs.get("multi_modal_data")
+        else:
+            assert_never(inputs)
+
+        return prompt, prompt_token_ids, multi_modal_data
+
+    async def _process_encoder_decoder_prompt_async(
+        self,
         inputs: PromptInputs,
+        request_id: str,
+    ) -> EncoderDecoderLLMInputs:
+        """Async version of :meth:`_process_encoder_decoder_prompt`."""
+        encoder_comps: PromptComponents
+        decoder_comps: DecoderPromptComponents
+
+        if is_explicit_encoder_decoder_prompt(inputs):
+            encoder_task = self._extract_prompt_components_async(
+                inputs["encoder_prompt"],
+                request_id=request_id,
+            )
+
+            if (decoder_input := inputs["decoder_prompt"]) is None:
+                encoder_comps = await encoder_task
+                decoder_comps = None, None, None
+            else:
+                decoder_task = self._extract_prompt_components_async(
+                    decoder_input,
+                    request_id=request_id,
+                )
+
+                encoder_comps, decoder_comps = await asyncio.gather(
+                    encoder_task, decoder_task)
+        else:
+            encoder_comps = await self._extract_prompt_components_async(
+                inputs,
+                request_id=request_id,
+            )
+
+            decoder_comps = None, None, None
+
+        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
+
+    async def _process_decoder_only_prompt_async(
+        self,
+        inputs: SingletonPromptInputs,
+        request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> LLMInputs:
-        if isinstance(inputs, str):
-            inputs = {"prompt": inputs}
+        """Async version of :meth:`_process_decoder_only_prompt`."""
+        prompt_comps = await self._extract_prompt_components_async(
+            inputs,
+            request_id=request_id,
+            lora_request=lora_request,
+        )
 
-        if "prompt_token_ids" not in inputs:
-            tokenizer = self.get_tokenizer_group("prompts must be None if "
-                                                 "skip_tokenizer_init is True")
+        return self._build_decoder_only_llm_inputs(
+            prompt_comps,
+            prompt_adapter_request=prompt_adapter_request,
+        )
 
-            prompt_token_ids = await tokenizer.encode_async(
+    async def process_model_inputs_async(
+        self,
+        inputs: PromptInputs,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
+        """Async version of :meth:`process_model_inputs`."""
+        if self.is_encoder_decoder_model():
+            # Encoder-decoder model requires special mapping of
+            # input prompts to encoder & decoder
+            model_inputs = await self._process_encoder_decoder_prompt_async(
+                inputs,
                 request_id=request_id,
-                prompt=inputs["prompt"],
-                lora_request=lora_request)
+            )
         else:
-            prompt_token_ids = inputs["prompt_token_ids"]
+            if is_explicit_encoder_decoder_prompt(inputs):
+                raise ValueError("Cannot pass encoder-decoder prompt "
+                                 "to decoder-only models")
 
-        if prompt_adapter_request:
-            prompt_token_ids = [
-                0
-            ] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \
-                prompt_token_ids
+            # Decoder-only operation
+            model_inputs = await self._process_decoder_only_prompt_async(
+                inputs,
+                request_id=request_id,
+                lora_request=lora_request,
+                prompt_adapter_request=prompt_adapter_request,
+            )
 
-        llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
-                               prompt=inputs.get("prompt"),
-                               multi_modal_data=inputs.get("multi_modal_data"))
-
-        return self.input_processor(llm_inputs)
+        return self.input_processor(model_inputs)
 
     async def add_request_async(
         self,
@@ -336,6 +440,7 @@ class _AsyncLLMEngine(LLMEngine):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
+        """Async version of :meth:`add_request`."""
         if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
@@ -343,10 +448,11 @@ class _AsyncLLMEngine(LLMEngine):
         arrival_time = time.time()
 
         processed_inputs = await self.process_model_inputs_async(
+            inputs,
             request_id=request_id,
-            inputs=inputs,
             lora_request=lora_request,
-            prompt_adapter_request=prompt_adapter_request)
+            prompt_adapter_request=prompt_adapter_request,
+        )
 
         self._add_processed_request(
             request_id=request_id,
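
For context, here is a minimal, self-contained sketch of the concurrency pattern that _process_encoder_decoder_prompt_async introduces above: tokenize the encoder and decoder prompts concurrently with asyncio.gather, and skip the decoder work entirely when decoder_prompt is None. This is an illustration only, not part of the commit; tokenize_async and process_encoder_decoder are hypothetical stand-ins for the diff's _extract_prompt_components_async path.

import asyncio
from typing import List, Optional, Tuple


async def tokenize_async(prompt: str) -> List[int]:
    # Hypothetical stand-in for an async tokenizer call
    # (tokenizer.encode_async in the diff above).
    await asyncio.sleep(0)  # yield to the event loop
    return [ord(ch) for ch in prompt]


async def process_encoder_decoder(
    encoder_prompt: str,
    decoder_prompt: Optional[str],
) -> Tuple[List[int], Optional[List[int]]]:
    encoder_task = tokenize_async(encoder_prompt)
    if decoder_prompt is None:
        # No decoder prompt: only the encoder side needs tokenizing.
        return (await encoder_task, None)
    # Tokenize both prompts concurrently, mirroring
    # asyncio.gather(encoder_task, decoder_task) in the diff.
    encoder_ids, decoder_ids = await asyncio.gather(
        encoder_task, tokenize_async(decoder_prompt))
    return encoder_ids, decoder_ids


print(asyncio.run(process_encoder_decoder("encoder text", "decoder text")))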