[Refactor] Implement output type check in LLM (#34794)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
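This refactor threads an explicit output_type through _run_completion / _run_chat / _run_engine, adds typed overloads to wait_for_completion(), and replaces the pass-through LLMEngine.validate_outputs() hook with an isinstance assertion inside _run_engine.

A minimal usage sketch of the new output_type parameter (illustrative only, not part of this commit; the model name, prompt, and the way sampling parameters are passed to enqueue() are assumptions):

    from vllm import LLM, RequestOutput, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # placeholder model

    # enqueue() submits requests without blocking; only its prompts argument
    # appears in this diff, so the sampling-params argument here is assumed.
    llm.enqueue("Hello, my name is", SamplingParams(max_tokens=16))

    # Without output_type, results are checked against
    # (RequestOutput, PoolingRequestOutput); passing a type narrows both the
    # runtime isinstance assertion in _run_engine and the static return type.
    outputs = llm.wait_for_completion(RequestOutput)
    for out in outputs:
        print(out.outputs[0].text)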
@@ -10,7 +10,7 @@ import cloudpickle
 import torch.nn as nn
 from pydantic import ValidationError
 from tqdm.auto import tqdm
-from typing_extensions import TypeVar
+from typing_extensions import TypeVar, overload

 from vllm.beam_search import (
     BeamSearchInstance,
@@ -94,6 +94,11 @@ if TYPE_CHECKING:

 logger = init_logger(__name__)

+_O = TypeVar(
+    "_O",
+    bound=RequestOutput | PoolingRequestOutput,
+    default=RequestOutput | PoolingRequestOutput,
+)
 _P = TypeVar("_P", bound=SamplingParams | PoolingParams | None)
 _R = TypeVar("_R", default=Any)

@@ -447,17 +452,16 @@ class LLM:
         if sampling_params is None:
            sampling_params = self.get_default_sampling_params()

-        outputs = self._run_completion(
+        return self._run_completion(
             prompts=prompts,
             params=sampling_params,
+            output_type=RequestOutput,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
             tokenization_kwargs=tokenization_kwargs,
             priority=priority,
         )
-
-        return self.engine_class.validate_outputs(outputs, RequestOutput)

     def enqueue(
         self,
         prompts: PromptType | Sequence[PromptType],
@@ -524,23 +528,43 @@ class LLM:

         return request_ids

+    @overload
     def wait_for_completion(
         self,
         *,
         use_tqdm: bool | Callable[..., tqdm] = True,
-    ) -> list[RequestOutput]:
+    ) -> list[RequestOutput | PoolingRequestOutput]: ...
+
+    @overload
+    def wait_for_completion(
+        self,
+        output_type: type[_O] | tuple[type[_O], ...],
+        *,
+        use_tqdm: bool | Callable[..., tqdm] = True,
+    ) -> list[_O]: ...
+
+    def wait_for_completion(
+        self,
+        output_type: type[Any] | tuple[type[Any], ...] | None = None,
+        *,
+        use_tqdm: bool | Callable[..., tqdm] = True,
+    ) -> list[Any]:
         """Wait for all enqueued requests to complete and return results.

         This method processes all requests currently in the engine queue
         and returns their outputs. Use after enqueue() to get results.

         Args:
+            output_type: The expected output type, defaults to RequestOutput.
             use_tqdm: If True, shows a tqdm progress bar.

         Returns:
-            A list of RequestOutput objects for all completed requests.
+            A list of output objects for all completed requests.
         """
-        outputs = self._run_engine(use_tqdm=use_tqdm)
-        return self.engine_class.validate_outputs(outputs, RequestOutput)
+        if output_type is None:
+            output_type = (RequestOutput, PoolingRequestOutput)
+
+        return self._run_engine(output_type, use_tqdm=use_tqdm)

     def _resolve_mm_lora(
         self,
@@ -744,13 +768,13 @@ class LLM:

             # only runs for one step
             # we don't need to use tqdm here
-            raw_output = self._render_and_run_requests(
+            output = self._render_and_run_requests(
                 prompts=(beam.get_prompt() for beam in all_beams),
                 params=self._params_to_seq(sampling_params, len(all_beams)),
+                output_type=RequestOutput,
                 lora_requests=[beam.lora_request for beam in all_beams],
                 use_tqdm=False,
             )
-            output = self.engine_class.validate_outputs(raw_output, RequestOutput)

             for (start, end), instance in zip(
                 instance_start_and_end, instances_batch
@@ -987,9 +1011,10 @@ class LLM:
         if sampling_params is None:
             sampling_params = self.get_default_sampling_params()

-        outputs = self._run_chat(
+        return self._run_chat(
             messages=messages,
             params=sampling_params,
+            output_type=RequestOutput,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
             chat_template=chat_template,
@@ -1002,8 +1027,6 @@ class LLM:
             mm_processor_kwargs=mm_processor_kwargs,
         )

-        return self.engine_class.validate_outputs(outputs, RequestOutput)
-
     def encode(
         self,
         prompts: PromptType | Sequence[PromptType] | DataPrompt,
@@ -1135,19 +1158,16 @@ class LLM:
         outputs = self._run_completion(
             prompts=prompts_seq,
             params=params_seq,
+            output_type=PoolingRequestOutput,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
             tokenization_kwargs=tokenization_kwargs,
         )

-        model_outputs = self.engine_class.validate_outputs(
-            outputs, PoolingRequestOutput
-        )
-
         if use_io_processor:
             # get the post-processed model outputs
             assert self.io_processor is not None
-            processed_outputs = self.io_processor.post_process(model_outputs)
+            processed_outputs = self.io_processor.post_process(outputs)

             return [
                 PoolingRequestOutput[Any](
@@ -1160,8 +1180,8 @@ class LLM:
                     finished=True,
                 )
             ]
         else:
-            return model_outputs
+            return outputs

     def embed(
         self,
@@ -1353,8 +1373,7 @@ class LLM:
             embed_2=encoded_output_2,
         )

-        items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
-        return [ScoringRequestOutput.from_base(item) for item in items]
+        return [ScoringRequestOutput.from_base(item) for item in scores]

     def _late_interaction_score(
         self,
@@ -1393,7 +1412,7 @@ class LLM:
             )
             text_2.append(text)

-        encoded_output: list[PoolingRequestOutput] = self.encode(
+        encoded_output = self.encode(
             text_1 + text_2,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
@@ -1402,8 +1421,8 @@ class LLM:
             tokenization_kwargs=tokenization_kwargs,
         )

-        encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)]
-        encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(text_1) :]
+        encoded_output_1 = encoded_output[0 : len(text_1)]
+        encoded_output_2 = encoded_output[len(text_1) :]

         if len(encoded_output_1) == 1:
             encoded_output_1 = encoded_output_1 * len(encoded_output_2)
@@ -1434,8 +1453,7 @@ class LLM:
                 )
             )

-        items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
-        return [ScoringRequestOutput.from_base(item) for item in items]
+        return [ScoringRequestOutput.from_base(item) for item in scores]

     def _cross_encoding_score(
         self,
@@ -1491,13 +1509,12 @@ class LLM:
         outputs = self._run_completion(
             prompts=prompts,
             params=pooling_params_list,
+            output_type=PoolingRequestOutput,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
         )

-        items = self.engine_class.validate_outputs(outputs, PoolingRequestOutput)
-
-        return [ScoringRequestOutput.from_base(item) for item in items]
+        return [ScoringRequestOutput.from_base(item) for item in outputs]

     def score(
         self,
@@ -1759,6 +1776,7 @@ class LLM:
         params: SamplingParams
         | PoolingParams
         | Sequence[SamplingParams | PoolingParams],
+        output_type: type[_O],
         *,
         use_tqdm: bool | Callable[..., tqdm] = True,
         lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
@@ -1790,6 +1808,7 @@ class LLM:
                 )
             ),
             params=seq_params,
+            output_type=output_type,
             use_tqdm=use_tqdm,
             lora_requests=seq_lora_requests,
             priorities=seq_priority,
@@ -1802,6 +1821,7 @@ class LLM:
         params: SamplingParams
         | PoolingParams
         | Sequence[SamplingParams | PoolingParams],
+        output_type: type[_O],
         *,
         use_tqdm: bool | Callable[..., tqdm] = True,
         lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
@@ -1848,6 +1868,7 @@ class LLM:
                 )
             ),
             params=seq_params,
+            output_type=output_type,
             lora_requests=seq_lora_requests,
             use_tqdm=use_tqdm,
         )
@@ -1856,6 +1877,7 @@ class LLM:
         self,
         prompts: Iterable[ProcessorInputs],
         params: Sequence[SamplingParams | PoolingParams],
+        output_type: type[_O],
         *,
         lora_requests: Sequence[LoRARequest | None] | None = None,
         priorities: Sequence[int] | None = None,
@@ -1878,7 +1900,7 @@ class LLM:
             priorities=priorities,
         )

-        return self._run_engine(use_tqdm=use_tqdm)
+        return self._run_engine(output_type, use_tqdm=use_tqdm)

     def _render_and_add_requests(
         self,
@@ -1932,9 +1954,10 @@ class LLM:

     def _run_engine(
         self,
+        output_type: type[_O] | tuple[type[_O], ...],
         *,
         use_tqdm: bool | Callable[..., tqdm] = True,
-    ) -> list[RequestOutput | PoolingRequestOutput]:
+    ) -> list[_O]:
         # Initialize tqdm.
         if use_tqdm:
             num_requests = self.llm_engine.get_num_unfinished_requests()
@@ -1947,14 +1970,15 @@ class LLM:
             )

         # Run the engine.
-        outputs: list[RequestOutput | PoolingRequestOutput] = []
+        outputs: list[_O] = []
         total_in_toks = 0
         total_out_toks = 0
         while self.llm_engine.has_unfinished_requests():
             step_outputs = self.llm_engine.step()
             for output in step_outputs:
+                assert isinstance(output, output_type)
                 if output.finished:
-                    outputs.append(output)
+                    outputs.append(output)  # type: ignore[arg-type]
                     if use_tqdm:
                         if isinstance(output, RequestOutput):
                             # Calculate tokens only for RequestOutput

@@ -199,10 +199,6 @@ class LLMEngine:
             self.should_execute_dummy_batch = True
         return aggregated_has_unfinished

-    @classmethod
-    def validate_outputs(cls, outputs, output_type):
-        return outputs
-
     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         if not hasattr(self, "_supported_tasks"):
             # Cache the result