Revert "[Core] Rename PromptInputs to PromptType, and inputs to prompt" (#8750)
This commit is contained in:
@@ -12,7 +12,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
||||
apply_hf_chat_template,
|
||||
apply_mistral_chat_template,
|
||||
parse_chat_messages)
|
||||
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
|
||||
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
|
||||
from vllm.inputs.parse import parse_and_batch_prompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@@ -293,8 +293,8 @@ class LLM:
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
prompts: Union[PromptType, Sequence[PromptType]],
|
||||
/,
|
||||
inputs: Union[PromptInputs, Sequence[PromptInputs]],
|
||||
/, # We may enable `inputs` keyword after removing the old API
|
||||
*,
|
||||
sampling_params: Optional[Union[SamplingParams,
|
||||
Sequence[SamplingParams]]] = None,
|
||||
@@ -311,7 +311,7 @@ class LLM:
|
||||
)
|
||||
def generate(
|
||||
self,
|
||||
prompts: Union[Union[PromptType, Sequence[PromptType]],
|
||||
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
|
||||
Optional[Union[str, List[str]]]] = None,
|
||||
sampling_params: Optional[Union[SamplingParams,
|
||||
Sequence[SamplingParams]]] = None,
|
||||
@@ -329,9 +329,7 @@ class LLM:
|
||||
into a single list and pass it to this method.
|
||||
|
||||
Args:
|
||||
prompts: The prompts to the LLM. You may pass a sequence of prompts
|
||||
for batch inference. See :class:`~vllm.inputs.PromptType`
|
||||
for more details about the format of each prompts.
|
||||
inputs: A list of inputs to generate completions for.
|
||||
sampling_params: The sampling parameters for text generation. If
|
||||
None, we use the default sampling parameters.
|
||||
When it is a single value, it is applied to every prompt.
|
||||
@@ -357,13 +355,12 @@ class LLM:
|
||||
"models (XForCausalLM, XForConditionalGeneration).")
|
||||
|
||||
if prompt_token_ids is not None:
|
||||
parsed_prompts = self._convert_v1_inputs(
|
||||
inputs = self._convert_v1_inputs(
|
||||
prompts=cast(Optional[Union[str, List[str]]], prompts),
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
)
|
||||
else:
|
||||
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
|
||||
prompts)
|
||||
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
|
||||
|
||||
if isinstance(guided_options_request, dict):
|
||||
if len(guided_options_request) > 1:
|
||||
@@ -378,7 +375,7 @@ class LLM:
|
||||
sampling_params = SamplingParams()
|
||||
|
||||
self._validate_and_add_requests(
|
||||
prompts=parsed_prompts,
|
||||
inputs=inputs,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
@@ -533,9 +530,9 @@ class LLM:
|
||||
conversation, mm_data = parse_chat_messages(messages, model_config,
|
||||
tokenizer)
|
||||
|
||||
prompt_data: Union[str, List[int]]
|
||||
prompt: Union[str, List[int]]
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
prompt_data = apply_mistral_chat_template(
|
||||
prompt = apply_mistral_chat_template(
|
||||
tokenizer,
|
||||
messages=messages,
|
||||
chat_template=chat_template,
|
||||
@@ -543,7 +540,7 @@ class LLM:
|
||||
tools=tools,
|
||||
)
|
||||
else:
|
||||
prompt_data = apply_hf_chat_template(
|
||||
prompt = apply_hf_chat_template(
|
||||
tokenizer,
|
||||
conversation=conversation,
|
||||
chat_template=chat_template,
|
||||
@@ -551,17 +548,17 @@ class LLM:
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
prompt: PromptType
|
||||
if is_list_of(prompt_data, int):
|
||||
prompt = TokensPrompt(prompt_token_ids=prompt_data)
|
||||
inputs: PromptInputs
|
||||
if is_list_of(prompt, int):
|
||||
inputs = TokensPrompt(prompt_token_ids=prompt)
|
||||
else:
|
||||
prompt = TextPrompt(prompt=prompt_data)
|
||||
inputs = TextPrompt(prompt=prompt)
|
||||
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
inputs["multi_modal_data"] = mm_data
|
||||
|
||||
return self.generate(
|
||||
prompt,
|
||||
inputs,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
@@ -631,8 +628,8 @@ class LLM:
|
||||
@overload
|
||||
def encode(
|
||||
self,
|
||||
prompts: Union[PromptType, Sequence[PromptType]],
|
||||
/,
|
||||
inputs: Union[PromptInputs, Sequence[PromptInputs]],
|
||||
/, # We may enable `inputs` keyword after removing the old API
|
||||
*,
|
||||
pooling_params: Optional[Union[PoolingParams,
|
||||
Sequence[PoolingParams]]] = None,
|
||||
@@ -649,7 +646,7 @@ class LLM:
|
||||
)
|
||||
def encode(
|
||||
self,
|
||||
prompts: Union[Union[PromptType, Sequence[PromptType]],
|
||||
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
|
||||
Optional[Union[str, List[str]]]] = None,
|
||||
pooling_params: Optional[Union[PoolingParams,
|
||||
Sequence[PoolingParams]]] = None,
|
||||
@@ -665,9 +662,9 @@ class LLM:
|
||||
into a single list and pass it to this method.
|
||||
|
||||
Args:
|
||||
prompts: The prompts to the LLM. You may pass a sequence of prompts
|
||||
for batch inference. See :class:`~vllm.inputs.PromptType`
|
||||
for more details about the format of each prompts.
|
||||
inputs: The inputs to the LLM. You may pass a sequence of inputs for
|
||||
batch inference. See :class:`~vllm.inputs.PromptInputs`
|
||||
for more details about the format of each input.
|
||||
pooling_params: The pooling parameters for pooling. If None, we
|
||||
use the default pooling parameters.
|
||||
use_tqdm: Whether to use tqdm to display the progress bar.
|
||||
@@ -690,20 +687,19 @@ class LLM:
|
||||
)
|
||||
|
||||
if prompt_token_ids is not None:
|
||||
parsed_prompts = self._convert_v1_inputs(
|
||||
inputs = self._convert_v1_inputs(
|
||||
prompts=cast(Optional[Union[str, List[str]]], prompts),
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
)
|
||||
else:
|
||||
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
|
||||
prompts)
|
||||
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
|
||||
|
||||
if pooling_params is None:
|
||||
# Use default pooling params.
|
||||
pooling_params = PoolingParams()
|
||||
|
||||
self._validate_and_add_requests(
|
||||
prompts=parsed_prompts,
|
||||
inputs=inputs,
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
@@ -747,9 +743,9 @@ class LLM:
|
||||
raise ValueError("Either prompts or prompt_token_ids must be "
|
||||
"provided.")
|
||||
|
||||
parsed_prompts: List[PromptType] = []
|
||||
inputs: List[PromptInputs] = []
|
||||
for i in range(num_requests):
|
||||
item: PromptType
|
||||
item: PromptInputs
|
||||
|
||||
if prompts is not None:
|
||||
item = TextPrompt(prompt=prompts[i])
|
||||
@@ -758,24 +754,24 @@ class LLM:
|
||||
else:
|
||||
raise AssertionError
|
||||
|
||||
parsed_prompts.append(item)
|
||||
inputs.append(item)
|
||||
|
||||
return parsed_prompts
|
||||
return inputs
|
||||
|
||||
def _validate_and_add_requests(
|
||||
self,
|
||||
prompts: Union[PromptType, Sequence[PromptType]],
|
||||
inputs: Union[PromptInputs, Sequence[PromptInputs]],
|
||||
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
|
||||
Sequence[PoolingParams]],
|
||||
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
guided_options: Optional[GuidedDecodingRequest] = None,
|
||||
) -> None:
|
||||
if isinstance(prompts, (str, dict)):
|
||||
if isinstance(inputs, (str, dict)):
|
||||
# Convert a single prompt to a list.
|
||||
prompts = [prompts]
|
||||
inputs = [inputs]
|
||||
|
||||
num_requests = len(prompts)
|
||||
num_requests = len(inputs)
|
||||
if isinstance(params, list) and len(params) != num_requests:
|
||||
raise ValueError("The lengths of prompts and params "
|
||||
"must be the same.")
|
||||
@@ -792,9 +788,9 @@ class LLM:
|
||||
sp.output_kind = RequestOutputKind.FINAL_ONLY
|
||||
|
||||
# Add requests to the engine.
|
||||
for i, prompt in enumerate(prompts):
|
||||
for i, request_inputs in enumerate(inputs):
|
||||
self._add_request(
|
||||
prompt,
|
||||
request_inputs,
|
||||
params[i] if isinstance(params, Sequence) else params,
|
||||
lora_request=lora_request[i] if isinstance(
|
||||
lora_request, Sequence) else lora_request,
|
||||
@@ -803,7 +799,7 @@ class LLM:
|
||||
|
||||
def _add_request(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
inputs: PromptInputs,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
@@ -811,7 +807,7 @@ class LLM:
|
||||
request_id = str(next(self.request_counter))
|
||||
self.llm_engine.add_request(
|
||||
request_id,
|
||||
prompt,
|
||||
inputs,
|
||||
params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
|
||||
Reference in New Issue
Block a user