Revert "rename PromptInputs and inputs with backward compatibility (#8760) (#8810)

Simon Mo
2024-09-25 10:36:26 -07:00
committed by GitHub
parent 873edda6cf
commit 4f1ba0844b
21 changed files with 245 additions and 438 deletions


@@ -12,7 +12,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                           apply_hf_chat_template,
                                           apply_mistral_chat_template,
                                           parse_chat_messages)
-from vllm.inputs import PromptType, TextPrompt, TokensPrompt
+from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
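For context, the PromptInputs union restored by this revert accepts plain strings as well as the TextPrompt/TokensPrompt dicts imported above; a minimal sketch of the three accepted forms (the token IDs are made up for illustration):

from vllm.inputs import TextPrompt, TokensPrompt

# All three are valid members of the PromptInputs union after this revert.
plain_prompt = "Hello, my name is"
text_prompt = TextPrompt(prompt="Hello, my name is")
token_prompt = TokensPrompt(prompt_token_ids=[1, 15043, 29892])  # illustrative IDs only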
@@ -293,8 +293,8 @@ class LLM:
     @overload
     def generate(
         self,
-        prompts: Union[PromptType, Sequence[PromptType]],
-        /,
+        inputs: Union[PromptInputs, Sequence[PromptInputs]],
+        /,  # We may enable `inputs` keyword after removing the old API
         *,
         sampling_params: Optional[Union[SamplingParams,
                                         Sequence[SamplingParams]]] = None,
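The bare `/` keeps the first parameter positional-only, so the restored `inputs` name cannot yet be passed as a keyword (the comment on the restored line notes it may be enabled once the old API is removed). A small standalone illustration of that Python mechanism, unrelated to vLLM itself:

def generate(inputs, /, *, sampling_params=None):
    """Toy stand-in demonstrating the positional-only marker."""
    return inputs, sampling_params

generate("Hello")            # fine: the prompt is passed positionally
# generate(inputs="Hello")   # TypeError: 'inputs' is positional-only here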
@@ -304,13 +304,14 @@ class LLM:
         ...
 
     @deprecate_kwargs(
+        "prompts",
         "prompt_token_ids",
         is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
-        additional_message="Please use the 'prompts' parameter instead.",
+        additional_message="Please use the 'inputs' parameter instead.",
     )
     def generate(
         self,
-        prompts: Union[Union[PromptType, Sequence[PromptType]],
+        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
                        Optional[Union[str, List[str]]]] = None,
         sampling_params: Optional[Union[SamplingParams,
                                         Sequence[SamplingParams]]] = None,
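The restored decorator call deprecates both the `prompts` and `prompt_token_ids` keywords and points callers at `inputs`. The real helper lives in vllm.utils; the sketch below only illustrates what such a keyword-deprecation decorator can look like and is not the actual vLLM implementation:

import warnings
from functools import wraps

def deprecate_kwargs(*kws, is_deprecated=lambda: True, additional_message=""):
    """Warn when any of the named keyword arguments is supplied (sketch only)."""
    def wrapper(fn):
        @wraps(fn)
        def inner(*args, **kwargs):
            if is_deprecated():
                used = [k for k in kws if kwargs.get(k) is not None]
                if used:
                    warnings.warn(
                        f"The keyword arguments {used} are deprecated. "
                        f"{additional_message}",
                        DeprecationWarning,
                        stacklevel=2)
            return fn(*args, **kwargs)
        return inner
    return wrapper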
@@ -329,9 +330,7 @@ class LLM:
         into a single list and pass it to this method.
 
         Args:
-            prompts: The prompts to the LLM. You may pass a sequence of prompts
-                for batch inference. See :class:`~vllm.inputs.PromptType`
-                for more details about the format of each prompts.
+            inputs: A list of inputs to generate completions for.
             sampling_params: The sampling parameters for text generation. If
                 None, we use the default sampling parameters.
                 When it is a single value, it is applied to every prompt.
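As the docstring says, sampling_params can be a single object applied to every prompt or a per-prompt sequence of matching length; a typical usage sketch (the model name is chosen arbitrarily):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")

# One SamplingParams shared by every prompt.
outputs = llm.generate(["Hello, my name is", "The capital of France is"],
                       SamplingParams(temperature=0.8, max_tokens=16))

# Or one SamplingParams per prompt (lengths must match).
outputs = llm.generate(
    ["Hello, my name is", "The capital of France is"],
    [SamplingParams(temperature=0.0), SamplingParams(temperature=1.0)])

for output in outputs:
    print(output.outputs[0].text)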
@@ -359,13 +358,12 @@ class LLM:
"models (XForCausalLM, XForConditionalGeneration).")
if prompt_token_ids is not None:
parsed_prompts = self._convert_v1_inputs(
inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
)
else:
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
if isinstance(guided_options_request, dict):
if len(guided_options_request) > 1:
@@ -380,7 +378,7 @@ class LLM:
             sampling_params = SamplingParams()
 
         self._validate_and_add_requests(
-            prompts=parsed_prompts,
+            inputs=inputs,
             params=sampling_params,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
@@ -650,8 +648,8 @@ class LLM:
     @overload
     def encode(
         self,
-        prompts: Union[PromptType, Sequence[PromptType]],
-        /,
+        inputs: Union[PromptInputs, Sequence[PromptInputs]],
+        /,  # We may enable `inputs` keyword after removing the old API
         *,
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
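encode() mirrors generate() but is used with embedding models and takes PoolingParams instead of SamplingParams; a rough usage sketch (the model name and the output attribute are assumptions based on typical embedding examples of this era, not something this diff shows):

from vllm import LLM

llm = LLM(model="intfloat/e5-mistral-7b-instruct")  # an embedding model, for illustration
outputs = llm.encode(["A sentence to embed", "Another sentence"])
for out in outputs:
    print(len(out.outputs.embedding))  # embedding attribute assumed for this sketch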
@@ -661,13 +659,14 @@ class LLM:
         ...
 
     @deprecate_kwargs(
+        "prompts",
         "prompt_token_ids",
         is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
-        additional_message="Please use the 'prompts' parameter instead.",
+        additional_message="Please use the 'inputs' parameter instead.",
     )
     def encode(
         self,
-        prompts: Union[Union[PromptType, Sequence[PromptType]],
+        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
                        Optional[Union[str, List[str]]]] = None,
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
@@ -683,9 +682,9 @@ class LLM:
         into a single list and pass it to this method.
 
         Args:
-            prompts: The prompts to the LLM. You may pass a sequence of prompts
-                for batch inference. See :class:`~vllm.inputs.PromptType`
-                for more details about the format of each prompts.
+            inputs: The inputs to the LLM. You may pass a sequence of inputs for
+                batch inference. See :class:`~vllm.inputs.PromptInputs`
+                for more details about the format of each input.
             pooling_params: The pooling parameters for pooling. If None, we
                 use the default pooling parameters.
             use_tqdm: Whether to use tqdm to display the progress bar.
@@ -708,20 +707,19 @@ class LLM:
         )
 
         if prompt_token_ids is not None:
-            parsed_prompts = self._convert_v1_inputs(
+            inputs = self._convert_v1_inputs(
                 prompts=cast(Optional[Union[str, List[str]]], prompts),
                 prompt_token_ids=prompt_token_ids,
             )
         else:
-            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
-                                  prompts)
+            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
 
         if pooling_params is None:
             # Use default pooling params.
             pooling_params = PoolingParams()
 
         self._validate_and_add_requests(
-            prompts=parsed_prompts,
+            inputs=inputs,
             params=pooling_params,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
@@ -765,9 +763,9 @@ class LLM:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
parsed_prompts: List[PromptType] = []
inputs: List[PromptInputs] = []
for i in range(num_requests):
item: PromptType
item: PromptInputs
if prompts is not None:
item = TextPrompt(prompt=prompts[i])
@@ -776,13 +774,13 @@ class LLM:
             else:
                 raise AssertionError
 
-            parsed_prompts.append(item)
+            inputs.append(item)
 
-        return parsed_prompts
+        return inputs
 
     def _validate_and_add_requests(
         self,
-        prompts: Union[PromptType, Sequence[PromptType]],
+        inputs: Union[PromptInputs, Sequence[PromptInputs]],
         params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                       Sequence[PoolingParams]],
         lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
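_convert_v1_inputs, partially shown in the two hunks above, is what keeps the legacy prompts/prompt_token_ids keywords working: it wraps each legacy value in the corresponding typed dict. A simplified standalone sketch of that mapping (the helper name and the trimmed error handling are mine):

from typing import List, Optional

from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt

def convert_v1_inputs(
        prompts: Optional[List[str]],
        prompt_token_ids: Optional[List[List[int]]]) -> List[PromptInputs]:
    # Legacy API: exactly one of `prompts` / `prompt_token_ids` is provided.
    num_requests = len(prompts) if prompts is not None else len(prompt_token_ids)
    inputs: List[PromptInputs] = []
    for i in range(num_requests):
        if prompts is not None:
            inputs.append(TextPrompt(prompt=prompts[i]))
        else:
            inputs.append(TokensPrompt(prompt_token_ids=prompt_token_ids[i]))
    return inputs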
@@ -790,11 +788,11 @@ class LLM:
         guided_options: Optional[GuidedDecodingRequest] = None,
         priority: Optional[List[int]] = None,
     ) -> None:
-        if isinstance(prompts, (str, dict)):
+        if isinstance(inputs, (str, dict)):
             # Convert a single prompt to a list.
-            prompts = [prompts]
+            inputs = [inputs]
 
-        num_requests = len(prompts)
+        num_requests = len(inputs)
         if isinstance(params, list) and len(params) != num_requests:
             raise ValueError("The lengths of prompts and params "
                              "must be the same.")
@@ -811,9 +809,9 @@ class LLM:
                 sp.output_kind = RequestOutputKind.FINAL_ONLY
 
         # Add requests to the engine.
-        for i, prompt in enumerate(prompts):
+        for i, request_inputs in enumerate(inputs):
             self._add_request(
-                prompt,
+                request_inputs,
                 params[i] if isinstance(params, Sequence) else params,
                 lora_request=lora_request[i] if isinstance(
                     lora_request, Sequence) else lora_request,
@@ -823,7 +821,7 @@ class LLM:
 
     def _add_request(
         self,
-        prompt: PromptType,
+        inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -832,7 +830,7 @@ class LLM:
         request_id = str(next(self.request_counter))
         self.llm_engine.add_request(
             request_id,
-            prompt,
+            inputs,
             params,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
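Finally, each input is tagged with a fresh request id and handed to the engine; a rough sketch of that submission step under the reverted names (a plain itertools counter stands in here for the counter utility the real class uses):

from itertools import count

request_counter = count()

def add_request(llm_engine, inputs, params, lora_request=None,
                prompt_adapter_request=None):
    request_id = str(next(request_counter))
    llm_engine.add_request(
        request_id,
        inputs,
        params,
        lora_request=lora_request,
        prompt_adapter_request=prompt_adapter_request)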