Revert "rename PromptInputs and inputs with backward compatibility (#8760) (#8810)

Simon Mo
2024-09-25 10:36:26 -07:00
committed by GitHub
parent 873edda6cf
commit 4f1ba0844b
21 changed files with 245 additions and 438 deletions
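
In effect, this revert restores the pre-#8760 calling convention: `AsyncLLMEngine.generate`, `encode`, and `add_request` take the request through an `inputs` parameter of type `PromptInputs`, and the `prompt`/`@deprecate_kwargs` compatibility shim introduced by #8760 is removed. A minimal sketch of calling the restored API (the model name and sampling values are illustrative, not part of this commit):

import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

# Build an engine; the model choice here is illustrative only.
engine = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs(model="facebook/opt-125m"))


async def main() -> None:
    # After this revert the parameter is named `inputs` again;
    # while #8760 was in, the same call used `prompt=`.
    async for output in engine.generate(
            inputs="Hello, my name is",
            sampling_params=SamplingParams(max_tokens=16),
            request_id="req-0"):
        final = output  # the last yielded RequestOutput is the finished one
    print(final.outputs[0].text)


asyncio.run(main())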

vllm/engine/async_llm_engine.py

@@ -2,8 +2,8 @@ import asyncio
 import time
 import weakref
 from functools import partial
-from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
-                    List, Mapping, Optional, Set, Tuple, Type, Union, overload)
+from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
+                    Mapping, Optional, Set, Tuple, Type, Union)
 from weakref import ReferenceType

 import vllm.envs as envs
@@ -17,7 +17,7 @@ from vllm.engine.metrics_types import StatLoggerBase
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutorAsync
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptType
+from vllm.inputs import PromptInputs
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -28,7 +28,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import deprecate_kwargs, weak_bind
+from vllm.utils import weak_bind

 logger = init_logger(__name__)

 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -402,54 +402,17 @@ class _AsyncLLMEngine(LLMEngine):
         """Stop the remote worker execution loop."""
         await self.model_executor.stop_remote_worker_execution_loop_async()

-    @overload  # DEPRECATED
     async def add_request_async(
         self,
         request_id: str,
-        *,
-        inputs: PromptType,
+        inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
-        ...
-
-    @overload
-    async def add_request_async(
-        self,
-        request_id: str,
-        prompt: PromptType,
-        params: Union[SamplingParams, PoolingParams],
-        arrival_time: Optional[float] = None,
-        lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> None:
-        ...
-
-    @deprecate_kwargs(
-        "inputs",
-        additional_message="Please use the 'prompt' parameter instead.",
-    )
-    async def add_request_async(
-        self,
-        request_id: str,
-        prompt: Optional[PromptType] = None,
-        params: Optional[Union[SamplingParams, PoolingParams]] = None,
-        arrival_time: Optional[float] = None,
-        lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-        *,
-        inputs: Optional[PromptType] = None,  # DEPRECATED
-    ) -> None:
         """Async version of :meth:`add_request`."""
-        if inputs is not None:
-            prompt = inputs
-        assert prompt is not None and params is not None
-
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
@@ -457,7 +420,7 @@ class _AsyncLLMEngine(LLMEngine):
             arrival_time = time.time()

         preprocessed_inputs = await self.input_preprocessor.preprocess_async(
-            prompt,
+            inputs,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
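
For context on what the revert deletes above: `@deprecate_kwargs` is the keyword-deprecation shim from `vllm.utils` that warned callers off the old keyword while both spellings worked. A minimal sketch of the pattern (an illustration only, not vLLM's actual implementation):

import functools
import warnings
from typing import Callable, Optional


def deprecate_kwargs(*names: str,
                     additional_message: Optional[str] = None) -> Callable:
    """Warn whenever one of the named keyword arguments is passed."""

    def wrapper(fn: Callable) -> Callable:

        @functools.wraps(fn)
        def inner(*args, **kwargs):
            for name in names:
                if kwargs.get(name) is not None:
                    msg = f"The keyword argument '{name}' is deprecated."
                    if additional_message:
                        msg += f" {additional_message}"
                    warnings.warn(msg, DeprecationWarning, stacklevel=2)
            return fn(*args, **kwargs)

        return inner

    return wrapper

Combined with the `@overload` stubs being removed here, this let type checkers see both the `prompt=` and `inputs=` signatures while the runtime accepted either; the revert drops all of that in favor of the single `inputs` parameter.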
@@ -811,55 +774,16 @@ class AsyncLLMEngine:
     # This method does not need to be async, but kept that way
     # for backwards compatibility.
-    @overload  # DEPRECATED
-    def add_request(
-        self,
-        request_id: str,
-        *,
-        inputs: PromptType,
-        params: Union[SamplingParams, PoolingParams],
-        arrival_time: Optional[float] = None,
-        lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> Coroutine[None, None, AsyncGenerator[Union[
-            RequestOutput, EmbeddingRequestOutput], None]]:
-        ...
-
-    @overload
-    def add_request(
-        self,
-        request_id: str,
-        prompt: PromptType,
-        params: Union[SamplingParams, PoolingParams],
-        arrival_time: Optional[float] = None,
-        lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> Coroutine[None, None, AsyncGenerator[Union[
-            RequestOutput, EmbeddingRequestOutput], None]]:
-        ...
-
-    @deprecate_kwargs(
-        "inputs",
-        additional_message="Please use the 'prompt' parameter instead.",
-    )
     async def add_request(
         self,
         request_id: str,
-        prompt: Optional[PromptType] = None,
-        params: Optional[Union[SamplingParams, PoolingParams]] = None,
+        inputs: PromptInputs,
+        params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-        *,
-        inputs: Optional[PromptType] = None,  # DEPRECATED
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
-        if inputs is not None:
-            prompt = inputs
-        assert prompt is not None and params is not None
-
         if not self.is_running:
             if self.start_engine_loop:
                 self.start_background_loop()
@@ -873,7 +797,7 @@ class AsyncLLMEngine:
         stream = self._request_tracker.add_request(
             request_id,
             verbose=self.log_requests,
-            prompt=prompt,
+            inputs=inputs,
            params=params,
            arrival_time=arrival_time or time.time(),
            lora_request=lora_request,
@@ -884,7 +808,7 @@ class AsyncLLMEngine:

     async def generate(
         self,
-        prompt: PromptType,
+        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
@@ -898,7 +822,8 @@ class AsyncLLMEngine:
         from the LLMEngine to the caller.

         Args:
-            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
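
For reference, the `PromptInputs` type this docstring points to accepts either a plain string or a typed dict; assuming the `TextPrompt`/`TokensPrompt` TypedDicts in `vllm.inputs` at this revision, any of the following forms is a valid `inputs` value:

from vllm.inputs import TextPrompt, TokensPrompt

# Three equivalent ways to submit a request (token ids are illustrative):
as_string = "What is the capital of France?"
as_text = TextPrompt(prompt="What is the capital of France?")
as_tokens = TokensPrompt(prompt_token_ids=[101, 2054, 2003, 1996, 3007])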
@@ -956,7 +881,7 @@ class AsyncLLMEngine:
         """
         async for output in await self.add_request(
                 request_id,
-                prompt,
+                inputs,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
@@ -966,7 +891,7 @@ class AsyncLLMEngine:

     async def encode(
         self,
-        prompt: PromptType,
+        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
@@ -979,7 +904,8 @@ class AsyncLLMEngine:
         from the LLMEngine to the caller.

         Args:
-            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
@@ -1033,7 +959,7 @@ class AsyncLLMEngine:
         """
         async for output in await self.add_request(
                 request_id,
-                prompt,
+                inputs,
                pooling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
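
As the last two hunks show, `encode` mirrors `generate` after the revert but takes `PoolingParams` and streams `EmbeddingRequestOutput`s. A hedged usage sketch, assuming an embedding-capable model and that `EmbeddingRequestOutput.outputs` carries an `embedding` list at this revision (both the model name and the request id are illustrative):

import asyncio

from vllm import PoolingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def embed(engine: AsyncLLMEngine, text: str, request_id: str):
    final = None
    # As with `generate`, the first parameter is `inputs` after this revert.
    async for output in engine.encode(
            inputs=text,
            pooling_params=PoolingParams(),
            request_id=request_id):
        final = output  # keep the last (finished) output
    return final.outputs.embedding  # assumed EmbeddingOutput.embedding field


engine = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs(model="intfloat/e5-mistral-7b-instruct"))  # illustrative
print(asyncio.run(embed(engine, "hello world", "req-0"))[:4])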