[Core] Support serving encoder/decoder models (#7258)

Cyrus Leung
2024-08-09 10:39:41 +08:00
committed by GitHub
parent 0fa14907da
commit 7eb4a51c5f
25 changed files with 603 additions and 464 deletions
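
At a high level, this change replaces the single process_model_inputs_async path with per-prompt-type helpers so encoder/decoder models get their own input-processing pipeline. The sketch below is illustrative only (it is not part of the diff); it shows the two prompt shapes the new helpers distinguish, using the dictionary keys the code below actually reads.

# Decoder-only ("singleton") prompt: a plain string, or a dict carrying
# either "prompt" or "prompt_token_ids" (optionally "multi_modal_data").
decoder_only_inputs = {"prompt": "The capital of France is"}

# Explicit encoder/decoder prompt: a dict pairing a singleton prompt for
# the encoder with one for the decoder; the decoder prompt may be None.
enc_dec_inputs = {
    "encoder_prompt": "What is the capital of France?",
    "decoder_prompt": None,
}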


@@ -5,6 +5,7 @@ from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
                     Optional, Set, Tuple, Type, Union)

 from transformers import PreTrainedTokenizer
+from typing_extensions import assert_never

 import vllm.envs as envs
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
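
The newly imported typing_extensions.assert_never is used below for exhaustiveness checking over the accepted input types. A minimal, self-contained illustration of the pattern (not taken from the diff):

from typing import Union
from typing_extensions import assert_never

def kind(inputs: Union[str, dict]) -> str:
    if isinstance(inputs, str):
        return "raw prompt"
    elif isinstance(inputs, dict):
        return "structured prompt"
    else:
        # The type checker proves this branch is unreachable; at runtime
        # it raises if an unexpected type slips through.
        assert_never(inputs)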
@@ -12,11 +13,14 @@ from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_timeout import asyncio_timeout
-from vllm.engine.llm_engine import LLMEngine
+from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
+                                    PromptComponents)
 from vllm.engine.metrics import StatLoggerBase
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.ray_utils import initialize_ray_cluster, ray
-from vllm.inputs import LLMInputs, PromptInputs
+from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
+                         SingletonPromptInputs)
+from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput
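
The new vllm.inputs.parse.is_explicit_encoder_decoder_prompt import is the runtime discriminator between the two prompt shapes. Judging only by how it is used in this diff, it amounts to a dict-shape check; a hedged sketch of such a predicate (the real one also narrows the static type):

from typing import Any

def looks_like_explicit_enc_dec_prompt(inputs: Any) -> bool:
    # Sketch only: observable behavior is a key check like this.
    return isinstance(inputs, dict) and "encoder_prompt" in inputs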
@@ -293,38 +297,138 @@ class _AsyncLLMEngine(LLMEngine):
         """Stop the remote worker execution loop."""
         await self.model_executor.stop_remote_worker_execution_loop_async()

-    async def process_model_inputs_async(
-        self,
-        request_id: str,
-        inputs: PromptInputs,
-        lora_request: Optional[LoRARequest] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> LLMInputs:
-        if isinstance(inputs, str):
-            inputs = {"prompt": inputs}
-
-        if "prompt_token_ids" not in inputs:
-            tokenizer = self.get_tokenizer_group("prompts must be None if "
-                                                 "skip_tokenizer_init is True")
-
-            prompt_token_ids = await tokenizer.encode_async(
-                request_id=request_id,
-                prompt=inputs["prompt"],
-                lora_request=lora_request)
-        else:
-            prompt_token_ids = inputs["prompt_token_ids"]
-
-        if prompt_adapter_request:
-            prompt_token_ids = [
-                0
-            ] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \
-                prompt_token_ids
-
-        llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
-                               prompt=inputs.get("prompt"),
-                               multi_modal_data=inputs.get("multi_modal_data"))
-
-        return self.input_processor(llm_inputs)
+    async def _tokenize_prompt_async(
+        self,
+        prompt: str,
+        request_id: str,
+        lora_request: Optional[LoRARequest],
+    ) -> List[int]:
+        """Async version of :meth:`_tokenize_prompt`."""
+        tokenizer = self.get_tokenizer_group("prompts must be None if "
+                                             "skip_tokenizer_init is True")
+
+        return await tokenizer.encode_async(request_id=request_id,
+                                            prompt=prompt,
+                                            lora_request=lora_request)
+
+    async def _extract_prompt_components_async(
+        self,
+        inputs: SingletonPromptInputs,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> PromptComponents:
+        """Async version of :meth:`_extract_prompt_components`."""
+        if isinstance(inputs, str):
+            prompt = inputs
+            prompt_token_ids = await self._tokenize_prompt_async(
+                prompt,
+                request_id=request_id,
+                lora_request=lora_request,
+            )
+            multi_modal_data = None
+        elif isinstance(inputs, dict):
+            if "prompt_token_ids" in inputs:
+                prompt = None
+                prompt_token_ids = inputs["prompt_token_ids"]
+            else:
+                # NOTE: This extra assignment is required to pass mypy
+                prompt = parsed_prompt = inputs["prompt"]
+                prompt_token_ids = await self._tokenize_prompt_async(
+                    parsed_prompt,
+                    request_id=request_id,
+                    lora_request=lora_request,
+                )
+
+            multi_modal_data = inputs.get("multi_modal_data")
+        else:
+            assert_never(inputs)
+
+        return prompt, prompt_token_ids, multi_modal_data
+
+    async def _process_encoder_decoder_prompt_async(
+        self,
+        inputs: PromptInputs,
+        request_id: str,
+    ) -> EncoderDecoderLLMInputs:
+        """Async version of :meth:`_process_encoder_decoder_prompt`."""
+        encoder_comps: PromptComponents
+        decoder_comps: DecoderPromptComponents
+
+        if is_explicit_encoder_decoder_prompt(inputs):
+            encoder_task = self._extract_prompt_components_async(
+                inputs["encoder_prompt"],
+                request_id=request_id,
+            )
+
+            if (decoder_input := inputs["decoder_prompt"]) is None:
+                encoder_comps = await encoder_task
+                decoder_comps = None, None, None
+            else:
+                decoder_task = self._extract_prompt_components_async(
+                    decoder_input,
+                    request_id=request_id,
+                )
+
+                encoder_comps, decoder_comps = await asyncio.gather(
+                    encoder_task, decoder_task)
+        else:
+            encoder_comps = await self._extract_prompt_components_async(
+                inputs,
+                request_id=request_id,
+            )
+
+            decoder_comps = None, None, None
+
+        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
+
+    async def _process_decoder_only_prompt_async(
+        self,
+        inputs: SingletonPromptInputs,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+    ) -> LLMInputs:
+        """Async version of :meth:`_process_decoder_only_prompt`."""
+        prompt_comps = await self._extract_prompt_components_async(
+            inputs,
+            request_id=request_id,
+            lora_request=lora_request,
+        )
+
+        return self._build_decoder_only_llm_inputs(
+            prompt_comps,
+            prompt_adapter_request=prompt_adapter_request,
+        )
+
+    async def process_model_inputs_async(
+        self,
+        inputs: PromptInputs,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
+        """Async version of :meth:`process_model_inputs`."""
+        if self.is_encoder_decoder_model():
+            # Encoder-decoder model requires special mapping of
+            # input prompts to encoder & decoder
+            model_inputs = await self._process_encoder_decoder_prompt_async(
+                inputs,
+                request_id=request_id,
+            )
+        else:
+            if is_explicit_encoder_decoder_prompt(inputs):
+                raise ValueError("Cannot pass encoder-decoder prompt "
+                                 "to decoder-only models")
+
+            # Decoder-only operation
+            model_inputs = await self._process_decoder_only_prompt_async(
+                inputs,
+                request_id=request_id,
+                lora_request=lora_request,
+                prompt_adapter_request=prompt_adapter_request,
+            )
+
+        return self.input_processor(model_inputs)

     async def add_request_async(
         self,
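
A design note on _process_encoder_decoder_prompt_async above: because _extract_prompt_components_async returns an awaitable, the encoder and decoder prompts are tokenized concurrently via asyncio.gather rather than one after the other. A standalone sketch of the pattern (tokenize_async is a hypothetical stand-in for the tokenizer call):

import asyncio
from typing import List, Tuple

async def tokenize_async(prompt: str) -> List[int]:
    # Hypothetical stand-in for tokenizer.encode_async in the diff.
    await asyncio.sleep(0.01)  # simulates an I/O-bound tokenizer call
    return [len(word) for word in prompt.split()]

async def tokenize_pair(encoder_prompt: str,
                        decoder_prompt: str) -> Tuple[List[int], List[int]]:
    # Mirrors the diff: create both awaitables first, then gather them so
    # the two tokenizations overlap rather than run sequentially.
    encoder_task = tokenize_async(encoder_prompt)
    decoder_task = tokenize_async(decoder_prompt)
    enc_ids, dec_ids = await asyncio.gather(encoder_task, decoder_task)
    return enc_ids, dec_ids

print(asyncio.run(tokenize_pair("hello world", "bonjour")))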
@@ -336,6 +440,7 @@ class _AsyncLLMEngine(LLMEngine):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
+        """Async version of :meth:`add_request`."""
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
@@ -343,10 +448,11 @@ class _AsyncLLMEngine(LLMEngine):
             arrival_time = time.time()

         processed_inputs = await self.process_model_inputs_async(
+            inputs,
             request_id=request_id,
-            inputs=inputs,
             lora_request=lora_request,
-            prompt_adapter_request=prompt_adapter_request)
+            prompt_adapter_request=prompt_adapter_request,
+        )

         self._add_processed_request(
             request_id=request_id,
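
For callers, the main observable change is input validation in process_model_inputs_async: explicit encoder/decoder prompts are now routed to the encoder/decoder pipeline and rejected for decoder-only models. A hedged usage sketch (assumes an already-initialized _AsyncLLMEngine named engine serving a decoder-only model):

async def demo(engine) -> None:
    # Ordinary singleton prompt: processed into LLMInputs as before.
    await engine.process_model_inputs_async(
        {"prompt": "Hello"}, request_id="r1")

    # Explicit encoder/decoder prompt against a decoder-only model:
    # the new code path raises instead of silently mis-parsing it.
    try:
        await engine.process_model_inputs_async(
            {"encoder_prompt": "Hello", "decoder_prompt": None},
            request_id="r2")
    except ValueError as exc:
        print(exc)  # Cannot pass encoder-decoder prompt to decoder-only models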