Rename `PromptInputs` to `PromptType`, and the `inputs` parameter to `prompts`, with backward compatibility (#8760)

This commit is contained in:
Cyrus Leung
2024-09-26 00:36:47 +08:00
committed by GitHub
parent 0c4d2ad5e6
commit 28e1299e60
21 changed files with 438 additions and 245 deletions

View File

@@ -12,7 +12,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages)
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -293,8 +293,8 @@ class LLM:
@overload
def generate(
self,
inputs: Union[PromptInputs, Sequence[PromptInputs]],
/, # We may enable `inputs` keyword after removing the old API
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
@@ -304,14 +304,13 @@ class LLM:
...
@deprecate_kwargs(
"prompts",
"prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter instead.",
additional_message="Please use the 'prompts' parameter instead.",
)
def generate(
self,
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
prompts: Union[Union[PromptType, Sequence[PromptType]],
Optional[Union[str, List[str]]]] = None,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
@@ -330,7 +329,9 @@ class LLM:
into a single list and pass it to this method.
Args:
inputs: A list of inputs to generate completions for.
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompt.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
@@ -358,12 +359,13 @@ class LLM:
"models (XForCausalLM, XForConditionalGeneration).")
if prompt_token_ids is not None:
inputs = self._convert_v1_inputs(
parsed_prompts = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
)
else:
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)
if isinstance(guided_options_request, dict):
if len(guided_options_request) > 1:
@@ -378,7 +380,7 @@ class LLM:
sampling_params = SamplingParams()
self._validate_and_add_requests(
inputs=inputs,
prompts=parsed_prompts,
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
@@ -648,8 +650,8 @@ class LLM:
@overload
def encode(
self,
inputs: Union[PromptInputs, Sequence[PromptInputs]],
/, # We may enable `inputs` keyword after removing the old API
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
@@ -659,14 +661,13 @@ class LLM:
...
@deprecate_kwargs(
"prompts",
"prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter instead.",
additional_message="Please use the 'prompts' parameter instead.",
)
def encode(
self,
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
prompts: Union[Union[PromptType, Sequence[PromptType]],
Optional[Union[str, List[str]]]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
@@ -682,9 +683,9 @@ class LLM:
into a single list and pass it to this method.
Args:
inputs: The inputs to the LLM. You may pass a sequence of inputs for
batch inference. See :class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompt.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
@@ -707,19 +708,20 @@ class LLM:
)
if prompt_token_ids is not None:
inputs = self._convert_v1_inputs(
parsed_prompts = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
)
else:
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
self._validate_and_add_requests(
inputs=inputs,
prompts=parsed_prompts,
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
@@ -763,9 +765,9 @@ class LLM:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
inputs: List[PromptInputs] = []
parsed_prompts: List[PromptType] = []
for i in range(num_requests):
item: PromptInputs
item: PromptType
if prompts is not None:
item = TextPrompt(prompt=prompts[i])
@@ -774,13 +776,13 @@ class LLM:
else:
raise AssertionError
inputs.append(item)
parsed_prompts.append(item)
return inputs
return parsed_prompts
def _validate_and_add_requests(
self,
inputs: Union[PromptInputs, Sequence[PromptInputs]],
prompts: Union[PromptType, Sequence[PromptType]],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
@@ -788,11 +790,11 @@ class LLM:
guided_options: Optional[GuidedDecodingRequest] = None,
priority: Optional[List[int]] = None,
) -> None:
if isinstance(inputs, (str, dict)):
if isinstance(prompts, (str, dict)):
# Convert a single prompt to a list.
inputs = [inputs]
prompts = [prompts]
num_requests = len(inputs)
num_requests = len(prompts)
if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params "
"must be the same.")
@@ -809,9 +811,9 @@ class LLM:
sp.output_kind = RequestOutputKind.FINAL_ONLY
# Add requests to the engine.
for i, request_inputs in enumerate(inputs):
for i, prompt in enumerate(prompts):
self._add_request(
request_inputs,
prompt,
params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
@@ -821,7 +823,7 @@ class LLM:
def _add_request(
self,
inputs: PromptInputs,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -830,7 +832,7 @@ class LLM:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(
request_id,
inputs,
prompt,
params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,