[Core] Rename PromptInputs and inputs (#8876)
@@ -12,7 +12,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                           apply_hf_chat_template,
                                           apply_mistral_chat_template,
                                           parse_chat_messages)
-from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
+from vllm.inputs import PromptType, TextPrompt, TokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -293,8 +293,8 @@ class LLM:
     @overload
     def generate(
         self,
-        inputs: Union[PromptInputs, Sequence[PromptInputs]],
-        /,  # We may enable `inputs` keyword after removing the old API
+        prompts: Union[PromptType, Sequence[PromptType]],
+        /,
         *,
         sampling_params: Optional[Union[SamplingParams,
                                         Sequence[SamplingParams]]] = None,
@@ -304,14 +304,13 @@ class LLM:
         ...
 
     @deprecate_kwargs(
-        "prompts",
         "prompt_token_ids",
         is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
-        additional_message="Please use the 'inputs' parameter instead.",
+        additional_message="Please use the 'prompts' parameter instead.",
     )
     def generate(
         self,
-        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
+        prompts: Union[Union[PromptType, Sequence[PromptType]],
                        Optional[Union[str, List[str]]]] = None,
         sampling_params: Optional[Union[SamplingParams,
                                         Sequence[SamplingParams]]] = None,
@@ -330,7 +329,9 @@ class LLM:
         into a single list and pass it to this method.
 
         Args:
-            inputs: A list of inputs to generate completions for.
+            prompts: The prompts to the LLM. You may pass a sequence of prompts
+                for batch inference. See :class:`~vllm.inputs.PromptType`
+                for more details about the format of each prompts.
             sampling_params: The sampling parameters for text generation. If
                 None, we use the default sampling parameters.
                 When it is a single value, it is applied to every prompt.
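For reference, a minimal usage sketch of generate() after the rename (the model name and sampling settings below are placeholders, not part of this commit). The first parameter is now prompts and accepts a string, a TextPrompt/TokensPrompt dict, or a sequence of these; the legacy prompt_token_ids keyword is still accepted but deprecated.

from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

llm = LLM(model="facebook/opt-125m")  # placeholder model
params = SamplingParams(temperature=0.8, max_tokens=32)

# New-style call: the prompts go in the first (positional) parameter.
outputs = llm.generate(["Hello, my name is", "The capital of France is"],
                       sampling_params=params)

# Pre-tokenized prompts now go through the same parameter as TokensPrompt dicts
# rather than the deprecated prompt_token_ids keyword.
outputs = llm.generate(TokensPrompt(prompt_token_ids=[1, 2, 3]),
                       sampling_params=params)

for output in outputs:
    print(output.outputs[0].text)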
@@ -358,12 +359,13 @@ class LLM:
                 "models (XForCausalLM, XForConditionalGeneration).")
 
         if prompt_token_ids is not None:
-            inputs = self._convert_v1_inputs(
+            parsed_prompts = self._convert_v1_inputs(
                 prompts=cast(Optional[Union[str, List[str]]], prompts),
                 prompt_token_ids=prompt_token_ids,
             )
         else:
-            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
+            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
+                                  prompts)
 
         if isinstance(guided_options_request, dict):
             if len(guided_options_request) > 1:
@@ -378,7 +380,7 @@ class LLM:
             sampling_params = SamplingParams()
 
         self._validate_and_add_requests(
-            inputs=inputs,
+            prompts=parsed_prompts,
             params=sampling_params,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
@@ -648,8 +650,8 @@ class LLM:
     @overload
     def encode(
         self,
-        inputs: Union[PromptInputs, Sequence[PromptInputs]],
-        /,  # We may enable `inputs` keyword after removing the old API
+        prompts: Union[PromptType, Sequence[PromptType]],
+        /,
         *,
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
@@ -659,14 +661,13 @@ class LLM:
         ...
 
     @deprecate_kwargs(
-        "prompts",
         "prompt_token_ids",
         is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
-        additional_message="Please use the 'inputs' parameter instead.",
+        additional_message="Please use the 'prompts' parameter instead.",
     )
     def encode(
         self,
-        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
+        prompts: Union[Union[PromptType, Sequence[PromptType]],
                        Optional[Union[str, List[str]]]] = None,
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
@@ -682,9 +683,9 @@ class LLM:
         into a single list and pass it to this method.
 
         Args:
-            inputs: The inputs to the LLM. You may pass a sequence of inputs for
-                batch inference. See :class:`~vllm.inputs.PromptInputs`
-                for more details about the format of each input.
+            prompts: The prompts to the LLM. You may pass a sequence of prompts
+                for batch inference. See :class:`~vllm.inputs.PromptType`
+                for more details about the format of each prompts.
             pooling_params: The pooling parameters for pooling. If None, we
                 use the default pooling parameters.
             use_tqdm: Whether to use tqdm to display the progress bar.
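The pooling path changes in the same way. A minimal sketch of encode() after the rename, assuming an embedding model (the model name below is only an example):

from vllm import LLM, PoolingParams

llm = LLM(model="intfloat/e5-mistral-7b-instruct")  # example embedding model

# prompts replaces the old inputs positional argument.
outputs = llm.encode(["What is the capital of France?"],
                     pooling_params=PoolingParams())

for output in outputs:
    # each result carries the pooled embedding vector
    print(len(output.outputs.embedding))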
@@ -707,19 +708,20 @@ class LLM:
             )
 
         if prompt_token_ids is not None:
-            inputs = self._convert_v1_inputs(
+            parsed_prompts = self._convert_v1_inputs(
                 prompts=cast(Optional[Union[str, List[str]]], prompts),
                 prompt_token_ids=prompt_token_ids,
             )
         else:
-            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
+            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
+                                  prompts)
 
         if pooling_params is None:
             # Use default pooling params.
             pooling_params = PoolingParams()
 
         self._validate_and_add_requests(
-            inputs=inputs,
+            prompts=parsed_prompts,
             params=pooling_params,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
@@ -763,9 +765,9 @@ class LLM:
             raise ValueError("Either prompts or prompt_token_ids must be "
                              "provided.")
 
-        inputs: List[PromptInputs] = []
+        parsed_prompts: List[PromptType] = []
         for i in range(num_requests):
-            item: PromptInputs
+            item: PromptType
 
             if prompts is not None:
                 item = TextPrompt(prompt=prompts[i])
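_convert_v1_inputs (this hunk and the next) keeps bridging the legacy prompts=str/List[str] plus prompt_token_ids calling convention onto the new PromptType list. A rough, hypothetical sketch of that mapping, simplified to already-batched lists (the real helper also runs parse_and_batch_prompt and validates lengths):

from typing import List, Optional

from vllm.inputs import PromptType, TextPrompt, TokensPrompt

def convert_legacy_prompts(
    prompts: Optional[List[str]],
    prompt_token_ids: Optional[List[List[int]]],
) -> List[PromptType]:
    # Simplified illustration only; not the actual implementation.
    num_requests = len(prompts) if prompts is not None else len(prompt_token_ids)
    parsed_prompts: List[PromptType] = []
    for i in range(num_requests):
        if prompts is not None:
            # Legacy text prompts become TextPrompt dicts.
            item: PromptType = TextPrompt(prompt=prompts[i])
        else:
            # Legacy token-id prompts become TokensPrompt dicts.
            item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
        parsed_prompts.append(item)
    return parsed_prompts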
@@ -774,13 +776,13 @@ class LLM:
             else:
                 raise AssertionError
 
-            inputs.append(item)
+            parsed_prompts.append(item)
 
-        return inputs
+        return parsed_prompts
 
     def _validate_and_add_requests(
         self,
-        inputs: Union[PromptInputs, Sequence[PromptInputs]],
+        prompts: Union[PromptType, Sequence[PromptType]],
         params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                       Sequence[PoolingParams]],
         lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
@@ -788,11 +790,11 @@ class LLM:
         guided_options: Optional[GuidedDecodingRequest] = None,
         priority: Optional[List[int]] = None,
     ) -> None:
-        if isinstance(inputs, (str, dict)):
+        if isinstance(prompts, (str, dict)):
             # Convert a single prompt to a list.
-            inputs = [inputs]
+            prompts = [prompts]
 
-        num_requests = len(inputs)
+        num_requests = len(prompts)
         if isinstance(params, list) and len(params) != num_requests:
             raise ValueError("The lengths of prompts and params "
                              "must be the same.")
@@ -809,9 +811,9 @@ class LLM:
                     sp.output_kind = RequestOutputKind.FINAL_ONLY
 
         # Add requests to the engine.
-        for i, request_inputs in enumerate(inputs):
+        for i, prompt in enumerate(prompts):
             self._add_request(
-                request_inputs,
+                prompt,
                 params[i] if isinstance(params, Sequence) else params,
                 lora_request=lora_request[i] if isinstance(
                     lora_request, Sequence) else lora_request,
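The two hunks above show how requests fan out: a single prompt is first wrapped into a one-element list, and params/lora_request may be either one value broadcast to every prompt or a sequence indexed per prompt. A small standalone sketch of that rule (hypothetical helper, for illustration only):

from typing import List, Sequence, Tuple, TypeVar, Union

T = TypeVar("T")

def _per_request(value: Union[T, Sequence[T]], i: int) -> T:
    # Mirrors "params[i] if isinstance(params, Sequence) else params" above.
    return value[i] if isinstance(value, Sequence) else value

def fan_out(prompts, params) -> List[Tuple[object, object]]:
    if isinstance(prompts, (str, dict)):
        prompts = [prompts]  # a single prompt becomes a one-element batch
    if isinstance(params, list) and len(params) != len(prompts):
        raise ValueError("The lengths of prompts and params must be the same.")
    return [(prompt, _per_request(params, i)) for i, prompt in enumerate(prompts)]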
@@ -821,7 +823,7 @@ class LLM:
 
     def _add_request(
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         params: Union[SamplingParams, PoolingParams],
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -830,7 +832,7 @@ class LLM:
         request_id = str(next(self.request_counter))
         self.llm_engine.add_request(
             request_id,
-            inputs,
+            prompt,
             params,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,