[Core] Consolidate prompt arguments to LLM engines (#4328)
Co-authored-by: Roger Wang <ywang@roblox.com>
@@ -12,12 +12,13 @@ from vllm.core.scheduler import SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.executor.ray_utils import initialize_ray_cluster, ray
+from vllm.inputs import LLMInputs, PromptInputs
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.usage.usage_lib import UsageContext
 
 logger = init_logger(__name__)
@@ -244,64 +245,69 @@ class _AsyncLLMEngine(LLMEngine):
 
         return request_outputs
 
-    async def encode_request_async(
+    async def process_model_inputs_async(
         self,
-        request_id: str,  # pylint: disable=unused-argument
-        prompt: Optional[str],
-        prompt_token_ids: Optional[List[int]] = None,
+        request_id: str,
+        inputs: PromptInputs,
         lora_request: Optional[LoRARequest] = None,
-    ):
-        if prompt_token_ids is None:
-            assert prompt is not None
-            prompt_token_ids = await self.tokenizer.encode_async(
+    ) -> LLMInputs:
+        if isinstance(inputs, str):
+            inputs = {"prompt": inputs}
+
+        if "prompt_token_ids" not in inputs:
+            tokenizer = self.get_tokenizer_group("prompts must be None if "
+                                                 "skip_tokenizer_init is True")
+
+            prompt_token_ids = await tokenizer.encode_async(
                 request_id=request_id,
-                prompt=prompt,
+                prompt=inputs["prompt"],
                 lora_request=lora_request)
-        return prompt_token_ids
+        else:
+            prompt_token_ids = inputs["prompt_token_ids"]
+
+        return LLMInputs(prompt_token_ids=prompt_token_ids,
+                         prompt=inputs.get("prompt"),
+                         multi_modal_data=inputs.get("multi_modal_data"))
 
     async def add_request_async(
         self,
         request_id: str,
-        prompt: Optional[str],
+        inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
-        prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
     ) -> None:
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
         if arrival_time is None:
             arrival_time = time.time()
-        prompt_token_ids = await self.encode_request_async(
-            request_id=request_id,
-            prompt=prompt,
-            prompt_token_ids=prompt_token_ids,
-            lora_request=lora_request)
 
-        return self.add_request(request_id,
-                                prompt=prompt,
-                                params=params,
-                                prompt_token_ids=prompt_token_ids,
-                                arrival_time=arrival_time,
-                                lora_request=lora_request,
-                                multi_modal_data=multi_modal_data)
+        processed_inputs = await self.process_model_inputs_async(
+            request_id=request_id, inputs=inputs, lora_request=lora_request)
+
+        self._add_processed_request(
+            request_id=request_id,
+            processed_inputs=processed_inputs,
+            params=params,
+            arrival_time=arrival_time,
+            lora_request=lora_request,
+        )
 
     async def check_health_async(self) -> None:
         self.model_executor.check_health()
 
 
 class AsyncLLMEngine:
-    """An asynchronous wrapper for LLMEngine.
+    """An asynchronous wrapper for :class:`LLMEngine`.
 
-    This class is used to wrap the LLMEngine class to make it asynchronous. It
-    uses asyncio to create a background loop that keeps processing incoming
-    requests. The LLMEngine is kicked by the generate method when there
-    are requests in the waiting queue. The generate method yields the outputs
-    from the LLMEngine to the caller.
+    This class is used to wrap the :class:`LLMEngine` class to make it
+    asynchronous. It uses asyncio to create a background loop that keeps
+    processing incoming requests. The :class:`LLMEngine` is kicked by the
+    generate method when there are requests in the waiting queue. The generate
+    method yields the outputs from the :class:`LLMEngine` to the caller.
 
-    NOTE: For the comprehensive list of arguments, see `LLMEngine`.
+    NOTE: For the comprehensive list of arguments, see :class:`LLMEngine`.
 
     Args:
         worker_use_ray: Whether to use Ray for model workers. Required for
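
For illustration, the normalization this hunk introduces can be sketched standalone: a PromptInputs value is either a plain string or a dict that may carry "prompt", "prompt_token_ids", and "multi_modal_data", and it is folded into a single LLMInputs record. The LLMInputs/PromptInputs definitions and the tokenize helper below are simplified stand-ins for illustration, not the real vLLM classes or tokenizer group.

from typing import List, Optional, TypedDict, Union


class LLMInputs(TypedDict):
    prompt_token_ids: List[int]
    prompt: Optional[str]
    multi_modal_data: Optional[object]


PromptInputs = Union[str, dict]


def tokenize(prompt: str) -> List[int]:
    # Hypothetical stand-in for the engine's async tokenizer group.
    return [ord(ch) for ch in prompt]


def process_model_inputs(inputs: PromptInputs) -> LLMInputs:
    # A bare string is promoted to {"prompt": <string>}.
    if isinstance(inputs, str):
        inputs = {"prompt": inputs}

    # Tokenize only when the caller did not pass token IDs directly.
    if "prompt_token_ids" not in inputs:
        prompt_token_ids = tokenize(inputs["prompt"])
    else:
        prompt_token_ids = inputs["prompt_token_ids"]

    return LLMInputs(prompt_token_ids=prompt_token_ids,
                     prompt=inputs.get("prompt"),
                     multi_modal_data=inputs.get("multi_modal_data"))


assert process_model_inputs("hi")["prompt_token_ids"] == [104, 105]
assert process_model_inputs({"prompt_token_ids": [1, 2]})["prompt"] is None
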
@@ -315,8 +321,8 @@ class AsyncLLMEngine:
             being printed in log.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
-        *args: Arguments for LLMEngine.
-        *kwargs: Arguments for LLMEngine.
+        *args: Arguments for :class:`LLMEngine`.
+        **kwargs: Arguments for :class:`LLMEngine`.
     """
 
     _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
@@ -526,22 +532,26 @@ class AsyncLLMEngine:
     async def add_request(
         self,
         request_id: str,
-        prompt: Optional[str],
+        inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
-        prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
     ) -> AsyncStream:
         if self.log_requests:
-            shortened_prompt = prompt
-            shortened_token_ids = prompt_token_ids
-            if self.max_log_len is not None:
+            if isinstance(inputs, str):
+                shortened_prompt = inputs
+                shortened_token_ids = None
+            else:
+                shortened_prompt = inputs.get("prompt")
+                shortened_token_ids = inputs.get("prompt_token_ids")
+
+            max_log_len = self.max_log_len
+            if max_log_len is not None:
                 if shortened_prompt is not None:
-                    shortened_prompt = shortened_prompt[:self.max_log_len]
+                    shortened_prompt = shortened_prompt[:max_log_len]
                 if shortened_token_ids is not None:
-                    shortened_token_ids = shortened_token_ids[:self.
-                                                              max_log_len]
+                    shortened_token_ids = shortened_token_ids[:max_log_len]
+
             logger.info(
                 "Received request %s: prompt: %r, "
                 "params: %s, prompt_token_ids: %s, "
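
The log-shortening branch above can be read in isolation: the logger now derives a truncated prompt/token-id preview from either form of PromptInputs. A small standalone sketch of that logic (the function name is illustrative, not part of the PR):

from typing import List, Optional, Tuple, Union


def shorten_for_log(
    inputs: Union[str, dict],
    max_log_len: Optional[int],
) -> Tuple[Optional[str], Optional[List[int]]]:
    # A bare string has no token IDs; a dict may carry either field.
    if isinstance(inputs, str):
        shortened_prompt: Optional[str] = inputs
        shortened_token_ids: Optional[List[int]] = None
    else:
        shortened_prompt = inputs.get("prompt")
        shortened_token_ids = inputs.get("prompt_token_ids")

    # Truncate both previews to max_log_len, mirroring the engine argument.
    if max_log_len is not None:
        if shortened_prompt is not None:
            shortened_prompt = shortened_prompt[:max_log_len]
        if shortened_token_ids is not None:
            shortened_token_ids = shortened_token_ids[:max_log_len]

    return shortened_prompt, shortened_token_ids


assert shorten_for_log("hello world", 5) == ("hello", None)
assert shorten_for_log({"prompt_token_ids": [1, 2, 3]}, 2) == (None, [1, 2])
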
@@ -562,39 +572,33 @@ class AsyncLLMEngine:
             arrival_time = time.time()
 
         if self.engine_use_ray:
-            prompt_token_ids = await (
-                self.engine.encode_request_async.remote(  # type: ignore
+            processed_inputs = await self.engine.process_model_inputs_async \
+                .remote(  # type: ignore
                     request_id=request_id,
-                    prompt=prompt,
-                    prompt_token_ids=prompt_token_ids,
-                    lora_request=lora_request))
+                    inputs=inputs,
+                    lora_request=lora_request)
         else:
-            prompt_token_ids = await self.engine.encode_request_async(
+            processed_inputs = await self.engine.process_model_inputs_async(
                 request_id=request_id,
-                prompt=prompt,
-                prompt_token_ids=prompt_token_ids,
+                inputs=inputs,
                 lora_request=lora_request)
 
         stream = self._request_tracker.add_request(
             request_id,
-            prompt=prompt,
+            inputs=processed_inputs,
             params=params,
-            prompt_token_ids=prompt_token_ids,
             arrival_time=arrival_time,
             lora_request=lora_request,
-            multi_modal_data=multi_modal_data,
         )
 
         return stream
 
     async def generate(
         self,
-        prompt: Optional[str],
+        inputs: PromptInputs,
         sampling_params: SamplingParams,
         request_id: str,
-        prompt_token_ids: Optional[List[int]] = None,
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None
     ) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.
 
@@ -603,14 +607,12 @@ class AsyncLLMEngine:
         from the LLMEngine to the caller.
 
         Args:
-            prompt: The prompt string. Can be None if prompt_token_ids is
-                provided.
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
+                for more details about the format of each input.
             sampling_params: The sampling parameters of the request.
             request_id: The unique id of the request.
-            prompt_token_ids: The token IDs of the prompt. If None, we
-                use the tokenizer to convert the prompts to token IDs.
             lora_request: LoRA request to use for generation, if any.
-            multi_modal_data: Multi modal data per request.
 
         Yields:
             The output `RequestOutput` objects from the LLMEngine
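
A hedged usage sketch of the consolidated generate() signature after this change. It is not taken from the PR; the model name and request id are placeholders, and it assumes vLLM as of this commit:

import asyncio

from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams


async def main() -> None:
    # Placeholder model; any supported causal LM works here.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))

    # A bare string is also accepted; the dict form is the new PromptInputs.
    inputs = {"prompt": "The capital of France is"}

    final = None
    async for output in engine.generate(inputs,
                                        SamplingParams(max_tokens=16),
                                        request_id="example-0"):
        final = output  # each iteration yields the cumulative RequestOutput

    if final is not None:
        print(final.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())
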
@@ -659,24 +661,20 @@ class AsyncLLMEngine:
         >>> # Process and return the final output
         >>> ...
         """
-        async for output in self.process_request(
+        async for output in self._process_request(
                 request_id,
-                prompt,
+                inputs,
                 sampling_params,
-                prompt_token_ids,
-                lora_request,
-                multi_modal_data,
+                lora_request=lora_request,
         ):
-            yield output
+            yield LLMEngine.validate_output(output, RequestOutput)
 
     async def encode(
         self,
-        prompt: Optional[str],
+        inputs: PromptInputs,
         pooling_params: PoolingParams,
         request_id: str,
-        prompt_token_ids: Optional[List[int]] = None,
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None
     ) -> AsyncIterator[EmbeddingRequestOutput]:
         """Generate outputs for a request from an embedding model.
 
@@ -685,14 +683,12 @@ class AsyncLLMEngine:
         from the LLMEngine to the caller.
 
         Args:
-            prompt: The prompt string. Can be None if prompt_token_ids is
-                provided.
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
+                for more details about the format of each input.
             pooling_params: The pooling parameters of the request.
             request_id: The unique id of the request.
-            prompt_token_ids: The token IDs of the prompt. If None, we
-                use the tokenizer to convert the prompts to token IDs.
             lora_request: LoRA request to use for generation, if any.
-            multi_modal_data: Multi modal data per request.
 
         Yields:
             The output `EmbeddingRequestOutput` objects from the LLMEngine
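
For completeness, a similarly hedged sketch of encode() with the consolidated inputs argument. The embedding model name and request id are placeholders, and the final.outputs.embedding attribute access is an assumption about EmbeddingRequestOutput, not something stated in this diff:

import asyncio

from vllm import AsyncEngineArgs, AsyncLLMEngine
from vllm.pooling_params import PoolingParams


async def main() -> None:
    # Placeholder embedding model; any embedding-capable model works.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="intfloat/e5-mistral-7b-instruct"))

    final = None
    async for output in engine.encode("Hello, world!",
                                      PoolingParams(),
                                      request_id="embed-0"):
        final = output  # yields EmbeddingRequestOutput objects

    if final is not None:
        # Assumed location of the embedding vector on the output object.
        print(len(final.outputs.embedding))


if __name__ == "__main__":
    asyncio.run(main())
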
@@ -739,24 +735,21 @@ class AsyncLLMEngine:
         >>> # Process and return the final output
         >>> ...
         """
-        async for output in self.process_request(
+        async for output in self._process_request(
                 request_id,
-                prompt,
+                inputs,
                 pooling_params,
-                prompt_token_ids,
-                lora_request,
-                multi_modal_data,
+                lora_request=lora_request,
         ):
-            yield output
+            yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
 
-    async def process_request(
+    async def _process_request(
         self,
         request_id: str,
-        prompt: Optional[str],
+        inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
-        prompt_token_ids: Optional[List[int]] = None,
+        *,
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
     ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Common logic to process requests with SamplingParams or
         PoolingParams."""
@@ -764,12 +757,10 @@ class AsyncLLMEngine:
 
         stream = await self.add_request(
             request_id,
-            prompt,
+            inputs,
             params,
-            prompt_token_ids=prompt_token_ids,
             arrival_time=arrival_time,
             lora_request=lora_request,
-            multi_modal_data=multi_modal_data,
         )
 
         try: