Revert "[Core] Rename PromptInputs to PromptType, and inputs to prompt" (#8750)

commit 3185fb0cca (parent 0250dd68c5)
Author: Simon Mo
Date:   2024-09-23 22:45:20 -07:00
Committed by: GitHub

18 changed files with 162 additions and 157 deletions

vllm/entrypoints/llm.py

@@ -12,7 +12,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                           apply_hf_chat_template,
                                           apply_mistral_chat_template,
                                           parse_chat_messages)
-from vllm.inputs import PromptType, TextPrompt, TokensPrompt
+from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
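For context on the revert, `PromptInputs` is again the accepted prompt union: a plain string, a `TextPrompt` dict, or a `TokensPrompt` dict of pre-tokenized ids. A hedged sketch of the three forms (not part of this diff; the model name and token ids are illustrative assumptions):

```python
# Illustrative only: model choice and token ids are assumptions, and running
# this requires vLLM at this commit plus the model weights.
from vllm import LLM, SamplingParams
from vllm.inputs import TextPrompt, TokensPrompt

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=0.0, max_tokens=16)

outputs = llm.generate(
    [
        "Hello, my name is",                             # bare string
        TextPrompt(prompt="The capital of France is"),   # text prompt dict
        TokensPrompt(prompt_token_ids=[2, 31414, 232]),  # pre-tokenized ids
    ],
    params,
)
for out in outputs:
    print(out.outputs[0].text)
```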
@@ -293,8 +293,8 @@ class LLM:
     @overload
     def generate(
         self,
-        prompts: Union[PromptType, Sequence[PromptType]],
-        /,
+        inputs: Union[PromptInputs, Sequence[PromptInputs]],
+        /,  # We may enable `inputs` keyword after removing the old API
         *,
         sampling_params: Optional[Union[SamplingParams,
                                         Sequence[SamplingParams]]] = None,
@@ -311,7 +311,7 @@ class LLM:
     )
     def generate(
         self,
-        prompts: Union[Union[PromptType, Sequence[PromptType]],
+        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
                        Optional[Union[str, List[str]]]] = None,
         sampling_params: Optional[Union[SamplingParams,
                                         Sequence[SamplingParams]]] = None,
@@ -329,9 +329,7 @@ class LLM:
         into a single list and pass it to this method.
 
         Args:
-            prompts: The prompts to the LLM. You may pass a sequence of prompts
-                for batch inference. See :class:`~vllm.inputs.PromptType`
-                for more details about the format of each prompts.
+            inputs: A list of inputs to generate completions for.
             sampling_params: The sampling parameters for text generation. If
                 None, we use the default sampling parameters.
                 When it is a single value, it is applied to every prompt.
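The restored docstring keeps the broadcast rule for `sampling_params`: a single object applies to every prompt, while a list must have the same length as the prompts and is paired with them one by one. A hedged usage sketch (not part of the diff; the model name is an assumption):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
prompts = ["Hello, my name is", "The future of AI is"]

# One params object, applied to every prompt.
shared = llm.generate(prompts, SamplingParams(temperature=0.0, max_tokens=8))

# Per-prompt params, paired one by one (lengths must match).
per_prompt = llm.generate(
    prompts,
    [
        SamplingParams(temperature=0.0, max_tokens=8),
        SamplingParams(temperature=0.8, top_p=0.95, max_tokens=32),
    ],
)
```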
@@ -357,13 +355,12 @@ class LLM:
                 "models (XForCausalLM, XForConditionalGeneration).")
 
         if prompt_token_ids is not None:
-            parsed_prompts = self._convert_v1_inputs(
+            inputs = self._convert_v1_inputs(
                 prompts=cast(Optional[Union[str, List[str]]], prompts),
                 prompt_token_ids=prompt_token_ids,
             )
         else:
-            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
-                                  prompts)
+            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
 
         if isinstance(guided_options_request, dict):
             if len(guided_options_request) > 1:
@@ -378,7 +375,7 @@ class LLM:
             sampling_params = SamplingParams()
 
         self._validate_and_add_requests(
-            prompts=parsed_prompts,
+            inputs=inputs,
             params=sampling_params,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
@@ -533,9 +530,9 @@ class LLM:
         conversation, mm_data = parse_chat_messages(messages, model_config,
                                                      tokenizer)
 
-        prompt_data: Union[str, List[int]]
+        prompt: Union[str, List[int]]
         if isinstance(tokenizer, MistralTokenizer):
-            prompt_data = apply_mistral_chat_template(
+            prompt = apply_mistral_chat_template(
                 tokenizer,
                 messages=messages,
                 chat_template=chat_template,
@@ -543,7 +540,7 @@ class LLM:
                 tools=tools,
             )
         else:
-            prompt_data = apply_hf_chat_template(
+            prompt = apply_hf_chat_template(
                 tokenizer,
                 conversation=conversation,
                 chat_template=chat_template,
@@ -551,17 +548,17 @@ class LLM:
                 tools=tools,
             )
 
-        prompt: PromptType
-        if is_list_of(prompt_data, int):
-            prompt = TokensPrompt(prompt_token_ids=prompt_data)
+        inputs: PromptInputs
+        if is_list_of(prompt, int):
+            inputs = TokensPrompt(prompt_token_ids=prompt)
         else:
-            prompt = TextPrompt(prompt=prompt_data)
+            inputs = TextPrompt(prompt=prompt)
 
         if mm_data is not None:
-            prompt["multi_modal_data"] = mm_data
+            inputs["multi_modal_data"] = mm_data
 
         return self.generate(
-            prompt,
+            inputs,
             sampling_params=sampling_params,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
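The branch above turns the chat-template output back into a `PromptInputs` value named `inputs`. A standalone sketch of that conversion (not vLLM source; the rendered strings and token ids are made up): a string becomes a `TextPrompt`, a token-id list becomes a `TokensPrompt`, and multi-modal data rides along under the `multi_modal_data` key.

```python
from typing import List, Optional, Union

from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt


def to_inputs(rendered: Union[str, List[int]],
              mm_data: Optional[dict] = None) -> PromptInputs:
    inputs: PromptInputs
    if isinstance(rendered, list):
        # Pre-tokenized chat template (e.g. from a Mistral tokenizer).
        inputs = TokensPrompt(prompt_token_ids=rendered)
    else:
        # Plain text rendered by an HF chat template.
        inputs = TextPrompt(prompt=rendered)
    if mm_data is not None:
        inputs["multi_modal_data"] = mm_data
    return inputs


print(to_inputs("<s>[INST] Hello [/INST]"))
print(to_inputs([1, 733, 16289, 28793]))
```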
@@ -631,8 +628,8 @@ class LLM:
     @overload
     def encode(
         self,
-        prompts: Union[PromptType, Sequence[PromptType]],
-        /,
+        inputs: Union[PromptInputs, Sequence[PromptInputs]],
+        /,  # We may enable `inputs` keyword after removing the old API
         *,
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
@@ -649,7 +646,7 @@ class LLM:
     )
     def encode(
         self,
-        prompts: Union[Union[PromptType, Sequence[PromptType]],
+        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
                        Optional[Union[str, List[str]]]] = None,
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
@@ -665,9 +662,9 @@ class LLM:
         into a single list and pass it to this method.
 
         Args:
-            prompts: The prompts to the LLM. You may pass a sequence of prompts
-                for batch inference. See :class:`~vllm.inputs.PromptType`
-                for more details about the format of each prompts.
+            inputs: The inputs to the LLM. You may pass a sequence of inputs for
+                batch inference. See :class:`~vllm.inputs.PromptInputs`
+                for more details about the format of each input.
             pooling_params: The pooling parameters for pooling. If None, we
                 use the default pooling parameters.
             use_tqdm: Whether to use tqdm to display the progress bar.
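The `encode` docstring above mirrors `generate` but for pooling/embedding models. A hedged sketch of the call (not part of the diff; the model name is an assumption, and any model vLLM can run in embedding mode should work):

```python
from vllm import LLM

# Embedding model choice is an assumption; it needs the weights and GPU memory.
llm = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)

outputs = llm.encode(["Hello, my name is", "The future of AI is"])
for out in outputs:
    print(len(out.outputs.embedding))  # embedding dimension for each prompt
```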
@@ -690,20 +687,19 @@ class LLM:
             )
 
         if prompt_token_ids is not None:
-            parsed_prompts = self._convert_v1_inputs(
+            inputs = self._convert_v1_inputs(
                 prompts=cast(Optional[Union[str, List[str]]], prompts),
                 prompt_token_ids=prompt_token_ids,
             )
         else:
-            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
-                                  prompts)
+            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
 
         if pooling_params is None:
             # Use default pooling params.
             pooling_params = PoolingParams()
 
         self._validate_and_add_requests(
-            prompts=parsed_prompts,
+            inputs=inputs,
             params=pooling_params,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
@@ -747,9 +743,9 @@ class LLM:
             raise ValueError("Either prompts or prompt_token_ids must be "
                              "provided.")
 
-        parsed_prompts: List[PromptType] = []
+        inputs: List[PromptInputs] = []
         for i in range(num_requests):
-            item: PromptType
+            item: PromptInputs
 
             if prompts is not None:
                 item = TextPrompt(prompt=prompts[i])
@@ -758,24 +754,24 @@ class LLM:
             else:
                 raise AssertionError
 
-            parsed_prompts.append(item)
+            inputs.append(item)
 
-        return parsed_prompts
+        return inputs
 
     def _validate_and_add_requests(
         self,
-        prompts: Union[PromptType, Sequence[PromptType]],
+        inputs: Union[PromptInputs, Sequence[PromptInputs]],
         params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                       Sequence[PoolingParams]],
         lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
         prompt_adapter_request: Optional[PromptAdapterRequest],
         guided_options: Optional[GuidedDecodingRequest] = None,
     ) -> None:
-        if isinstance(prompts, (str, dict)):
+        if isinstance(inputs, (str, dict)):
             # Convert a single prompt to a list.
-            prompts = [prompts]
+            inputs = [inputs]
 
-        num_requests = len(prompts)
+        num_requests = len(inputs)
         if isinstance(params, list) and len(params) != num_requests:
             raise ValueError("The lengths of prompts and params "
                              "must be the same.")
@@ -792,9 +788,9 @@ class LLM:
                 sp.output_kind = RequestOutputKind.FINAL_ONLY
 
         # Add requests to the engine.
-        for i, prompt in enumerate(prompts):
+        for i, request_inputs in enumerate(inputs):
             self._add_request(
-                prompt,
+                request_inputs,
                 params[i] if isinstance(params, Sequence) else params,
                 lora_request=lora_request[i] if isinstance(
                     lora_request, Sequence) else lora_request,
@@ -803,7 +799,7 @@ class LLM:
 
     def _add_request(
         self,
-        prompt: PromptType,
+        inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -811,7 +807,7 @@ class LLM:
         request_id = str(next(self.request_counter))
         self.llm_engine.add_request(
             request_id,
-            prompt,
+            inputs,
             params,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
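Finally, `_add_request` above draws a fresh string id from a counter and hands the prompt and its params to the engine. A minimal standalone sketch of that per-request pattern (not vLLM source; the list stands in for `llm_engine.add_request`):

```python
from itertools import count

request_counter = count()
pending = []  # stand-in for the engine's request queue


def add_request(inputs, params):
    request_id = str(next(request_counter))
    pending.append((request_id, inputs, params))
    return request_id


print(add_request("Hello, my name is", {"max_tokens": 8}))    # "0"
print(add_request("The future of AI is", {"max_tokens": 8}))  # "1"
```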