[Core] Consolidate prompt arguments to LLM engines (#4328)

Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Cyrus Leung
2024-05-29 04:29:31 +08:00
committed by GitHub
parent 290f4ada2b
commit 5ae5ed1e60
43 changed files with 1407 additions and 442 deletions

View File

@@ -12,12 +12,13 @@ from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.inputs import LLMInputs, PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.usage.usage_lib import UsageContext
logger = init_logger(__name__)
@@ -244,64 +245,69 @@ class _AsyncLLMEngine(LLMEngine):
return request_outputs
async def encode_request_async(
async def process_model_inputs_async(
self,
request_id: str, # pylint: disable=unused-argument
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
request_id: str,
inputs: PromptInputs,
lora_request: Optional[LoRARequest] = None,
):
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = await self.tokenizer.encode_async(
) -> LLMInputs:
if isinstance(inputs, str):
inputs = {"prompt": inputs}
if "prompt_token_ids" not in inputs:
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
prompt_token_ids = await tokenizer.encode_async(
request_id=request_id,
prompt=prompt,
prompt=inputs["prompt"],
lora_request=lora_request)
return prompt_token_ids
else:
prompt_token_ids = inputs["prompt_token_ids"]
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
async def add_request_async(
self,
request_id: str,
prompt: Optional[str],
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if arrival_time is None:
arrival_time = time.time()
prompt_token_ids = await self.encode_request_async(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
return self.add_request(request_id,
prompt=prompt,
params=params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data)
processed_inputs = await self.process_model_inputs_async(
request_id=request_id, inputs=inputs, lora_request=lora_request)
self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
)
async def check_health_async(self) -> None:
self.model_executor.check_health()
class AsyncLLMEngine:
"""An asynchronous wrapper for LLMEngine.
"""An asynchronous wrapper for :class:`LLMEngine`.
This class is used to wrap the LLMEngine class to make it asynchronous. It
uses asyncio to create a background loop that keeps processing incoming
requests. The LLMEngine is kicked by the generate method when there
are requests in the waiting queue. The generate method yields the outputs
from the LLMEngine to the caller.
This class is used to wrap the :class:`LLMEngine` class to make it
asynchronous. It uses asyncio to create a background loop that keeps
processing incoming requests. The :class:`LLMEngine` is kicked by the
generate method when there are requests in the waiting queue. The generate
method yields the outputs from the :class:`LLMEngine` to the caller.
NOTE: For the comprehensive list of arguments, see `LLMEngine`.
NOTE: For the comprehensive list of arguments, see :class:`LLMEngine`.
Args:
worker_use_ray: Whether to use Ray for model workers. Required for
@@ -315,8 +321,8 @@ class AsyncLLMEngine:
being printed in log.
start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call.
*args: Arguments for LLMEngine.
*kwargs: Arguments for LLMEngine.
*args: Arguments for :class:`LLMEngine`.
**kwargs: Arguments for :class:`LLMEngine`.
"""
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
@@ -526,22 +532,26 @@ class AsyncLLMEngine:
async def add_request(
self,
request_id: str,
prompt: Optional[str],
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> AsyncStream:
if self.log_requests:
shortened_prompt = prompt
shortened_token_ids = prompt_token_ids
if self.max_log_len is not None:
if isinstance(inputs, str):
shortened_prompt = inputs
shortened_token_ids = None
else:
shortened_prompt = inputs.get("prompt")
shortened_token_ids = inputs.get("prompt_token_ids")
max_log_len = self.max_log_len
if max_log_len is not None:
if shortened_prompt is not None:
shortened_prompt = shortened_prompt[:self.max_log_len]
shortened_prompt = shortened_prompt[:max_log_len]
if shortened_token_ids is not None:
shortened_token_ids = shortened_token_ids[:self.
max_log_len]
shortened_token_ids = shortened_token_ids[:max_log_len]
logger.info(
"Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, "
@@ -562,39 +572,33 @@ class AsyncLLMEngine:
arrival_time = time.time()
if self.engine_use_ray:
prompt_token_ids = await (
self.engine.encode_request_async.remote( # type: ignore
processed_inputs = await self.engine.process_model_inputs_async \
.remote( # type: ignore
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request))
inputs=inputs,
lora_request=lora_request)
else:
prompt_token_ids = await self.engine.encode_request_async(
processed_inputs = await self.engine.process_model_inputs_async(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
inputs=inputs,
lora_request=lora_request)
stream = self._request_tracker.add_request(
request_id,
prompt=prompt,
inputs=processed_inputs,
params=params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data,
)
return stream
async def generate(
self,
prompt: Optional[str],
inputs: PromptInputs,
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.
@@ -603,14 +607,12 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
Args:
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data per request.
Yields:
The output `RequestOutput` objects from the LLMEngine
@@ -659,24 +661,20 @@ class AsyncLLMEngine:
>>> # Process and return the final output
>>> ...
"""
async for output in self.process_request(
async for output in self._process_request(
request_id,
prompt,
inputs,
sampling_params,
prompt_token_ids,
lora_request,
multi_modal_data,
lora_request=lora_request,
):
yield output
yield LLMEngine.validate_output(output, RequestOutput)
async def encode(
self,
prompt: Optional[str],
inputs: PromptInputs,
pooling_params: PoolingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None
) -> AsyncIterator[EmbeddingRequestOutput]:
"""Generate outputs for a request from an embedding model.
@@ -685,14 +683,12 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
Args:
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data per request.
Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
@@ -739,24 +735,21 @@ class AsyncLLMEngine:
>>> # Process and return the final output
>>> ...
"""
async for output in self.process_request(
async for output in self._process_request(
request_id,
prompt,
inputs,
pooling_params,
prompt_token_ids,
lora_request,
multi_modal_data,
lora_request=lora_request,
):
yield output
yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
async def process_request(
async def _process_request(
self,
request_id: str,
prompt: Optional[str],
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
*,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Common logic to process requests with SamplingParams or
PoolingParams."""
@@ -764,12 +757,10 @@ class AsyncLLMEngine:
stream = await self.add_request(
request_id,
prompt,
inputs,
params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data,
)
try: