rename PromptInputs and inputs with backward compatibility (#8760)
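This change renames the `PromptInputs` type to `PromptType` and the corresponding `inputs` parameter to `prompt` throughout `_AsyncLLMEngine` and `AsyncLLMEngine`. For `add_request_async` and `add_request`, the old `inputs` keyword keeps working: `typing.overload` stubs preserve both signatures for type checkers, and a `deprecate_kwargs` decorator warns at runtime and forwards the value to `prompt`.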
@@ -2,8 +2,8 @@ import asyncio
 import time
 import weakref
 from functools import partial
-from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
-                    Mapping, Optional, Set, Tuple, Type, Union)
+from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
+                    List, Mapping, Optional, Set, Tuple, Type, Union, overload)
 from weakref import ReferenceType
 
 import vllm.envs as envs
@@ -17,7 +17,7 @@ from vllm.engine.metrics_types import StatLoggerBase
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutorAsync
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptInputs
+from vllm.inputs import PromptType
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -28,7 +28,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import weak_bind
+from vllm.utils import deprecate_kwargs, weak_bind
 
 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -402,17 +402,54 @@ class _AsyncLLMEngine(LLMEngine):
         """Stop the remote worker execution loop."""
         await self.model_executor.stop_remote_worker_execution_loop_async()
 
+    @overload  # DEPRECATED
     async def add_request_async(
         self,
         request_id: str,
-        inputs: PromptInputs,
+        *,
+        inputs: PromptType,
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
+        ...
+
+    @overload
+    async def add_request_async(
+        self,
+        request_id: str,
+        prompt: PromptType,
+        params: Union[SamplingParams, PoolingParams],
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+    ) -> None:
+        ...
+
+    @deprecate_kwargs(
+        "inputs",
+        additional_message="Please use the 'prompt' parameter instead.",
+    )
+    async def add_request_async(
+        self,
+        request_id: str,
+        prompt: Optional[PromptType] = None,
+        params: Optional[Union[SamplingParams, PoolingParams]] = None,
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        *,
+        inputs: Optional[PromptType] = None,  # DEPRECATED
+    ) -> None:
         """Async version of :meth:`add_request`."""
+        if inputs is not None:
+            prompt = inputs
+        assert prompt is not None and params is not None
+
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
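
For context, the backward compatibility here rests on two pieces: `typing.overload` stubs that keep both call signatures visible to type checkers, and the `deprecate_kwargs` decorator that warns at runtime when the old keyword is passed. Below is a minimal sketch of such a decorator, as an illustration of the pattern rather than vLLM's actual `vllm.utils.deprecate_kwargs` (the name `deprecate_kwargs_sketch` is hypothetical):

    import functools
    import warnings
    from typing import Any, Callable, Optional


    def deprecate_kwargs_sketch(
            *names: str,
            additional_message: Optional[str] = None) -> Callable:
        """Warn when any of the named keyword arguments is passed,
        then forward the call to the wrapped function unchanged."""

        def wrapper(fn: Callable) -> Callable:

            @functools.wraps(fn)
            def inner(*args: Any, **kwargs: Any) -> Any:
                # Collect any deprecated keywords actually used in this call.
                deprecated = sorted(k for k in names if k in kwargs)
                if deprecated:
                    msg = (f"The keyword arguments {deprecated} are "
                           "deprecated and will be removed in a future "
                           "release.")
                    if additional_message is not None:
                        msg += f" {additional_message}"
                    warnings.warn(msg, DeprecationWarning, stacklevel=2)
                return fn(*args, **kwargs)

            return inner

        return wrapper

Applied to an `async def`, the warning fires when the coroutine is created, before it is awaited, which is the behavior the `add_request_async` wrapper above needs.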
@@ -420,7 +457,7 @@ class _AsyncLLMEngine(LLMEngine):
             arrival_time = time.time()
 
         preprocessed_inputs = await self.input_preprocessor.preprocess_async(
-            inputs,
+            prompt,
             request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
@@ -774,16 +811,55 @@ class AsyncLLMEngine:
 
     # This method does not need to be async, but kept that way
     # for backwards compatibility.
-    async def add_request(
+    @overload  # DEPRECATED
+    def add_request(
         self,
         request_id: str,
-        inputs: PromptInputs,
+        *,
+        inputs: PromptType,
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+    ) -> Coroutine[None, None, AsyncGenerator[Union[
+            RequestOutput, EmbeddingRequestOutput], None]]:
+        ...
+
+    @overload
+    def add_request(
+        self,
+        request_id: str,
+        prompt: PromptType,
+        params: Union[SamplingParams, PoolingParams],
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+    ) -> Coroutine[None, None, AsyncGenerator[Union[
+            RequestOutput, EmbeddingRequestOutput], None]]:
+        ...
+
+    @deprecate_kwargs(
+        "inputs",
+        additional_message="Please use the 'prompt' parameter instead.",
+    )
+    async def add_request(
+        self,
+        request_id: str,
+        prompt: Optional[PromptType] = None,
+        params: Optional[Union[SamplingParams, PoolingParams]] = None,
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        *,
+        inputs: Optional[PromptType] = None,  # DEPRECATED
+    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+        if inputs is not None:
+            prompt = inputs
+        assert prompt is not None and params is not None
+
         if not self.is_running:
             if self.start_engine_loop:
                 self.start_background_loop()
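
A note on the return annotations above: the overload stubs are plain `def`s returning `Coroutine[None, None, AsyncGenerator[...]]` because the decorated implementation is an `async def` that returns the output stream, so a caller first awaits the coroutine and then iterates the generator, as the `generate` and `encode` bodies below do. A minimal caller sketch, assuming `engine` is an `AsyncLLMEngine` instance and the request produces `RequestOutput`s:

    stream = await engine.add_request("request-0",
                                      "Hello, my name is",
                                      SamplingParams(temperature=0.8))
    async for request_output in stream:
        print(request_output.outputs[0].text)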
@@ -797,7 +873,7 @@ class AsyncLLMEngine:
         stream = self._request_tracker.add_request(
             request_id,
             verbose=self.log_requests,
-            inputs=inputs,
+            prompt=prompt,
             params=params,
             arrival_time=arrival_time or time.time(),
             lora_request=lora_request,
@@ -808,7 +884,7 @@ class AsyncLLMEngine:
 
     async def generate(
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         sampling_params: SamplingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
@@ -822,8 +898,7 @@ class AsyncLLMEngine:
         from the LLMEngine to the caller.
 
         Args:
-            inputs: The inputs to the LLM. See
-                :class:`~vllm.inputs.PromptInputs`
+            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                 for more details about the format of each input.
             sampling_params: The sampling parameters of the request.
             request_id: The unique id of the request.
@@ -881,7 +956,7 @@ class AsyncLLMEngine:
         """
         async for output in await self.add_request(
                 request_id,
-                inputs,
+                prompt,
                 sampling_params,
                 lora_request=lora_request,
                 trace_headers=trace_headers,
@@ -891,7 +966,7 @@ class AsyncLLMEngine:
 
     async def encode(
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         pooling_params: PoolingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
@@ -904,8 +979,7 @@ class AsyncLLMEngine:
         from the LLMEngine to the caller.
 
         Args:
-            inputs: The inputs to the LLM. See
-                :class:`~vllm.inputs.PromptInputs`
+            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                 for more details about the format of each input.
             pooling_params: The pooling parameters of the request.
             request_id: The unique id of the request.
@@ -959,7 +1033,7 @@ class AsyncLLMEngine:
         """
        async for output in await self.add_request(
                 request_id,
-                inputs,
+                prompt,
                 pooling_params,
                 lora_request=lora_request,
                 trace_headers=trace_headers,
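
Note that in this file only `add_request` and `add_request_async` keep the deprecated `inputs=` keyword; `generate` and `encode` take the renamed parameter directly, so callers that passed `inputs=` by keyword must switch to `prompt=`. A migration sketch for a `generate` caller, assuming `engine` is an `AsyncLLMEngine` instance:

    # Before this change the keyword was 'inputs='; the first parameter
    # is now 'prompt' (a PromptType).
    async for request_output in engine.generate(
            prompt="Hello, my name is",
            sampling_params=SamplingParams(temperature=0.8),
            request_id="request-2"):
        print(request_output.outputs[0].text)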