[Refactor] Implement output type check in LLM (#34794)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2026-02-20 11:57:55 +08:00
Committed by: GitHub
Parent: 76df6072ff
Commit: ac900c89bb
2 changed files with 58 additions and 38 deletions


@@ -10,7 +10,7 @@ import cloudpickle
import torch.nn as nn
from pydantic import ValidationError
from tqdm.auto import tqdm
- from typing_extensions import TypeVar
+ from typing_extensions import TypeVar, overload
from vllm.beam_search import (
BeamSearchInstance,
@@ -94,6 +94,11 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
+ _O = TypeVar(
+ "_O",
+ bound=RequestOutput | PoolingRequestOutput,
+ default=RequestOutput | PoolingRequestOutput,
+ )
_P = TypeVar("_P", bound=SamplingParams | PoolingParams | None)
_R = TypeVar("_R", default=Any)
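The new `_O` variable relies on a `TypeVar` default (PEP 696, available via `typing_extensions`). As a rough, self-contained sketch of the pattern with placeholder classes rather than vLLM's real output types (Python 3.10+ for the `X | Y` union syntax):

    from typing_extensions import TypeVar

    class TextOutput: ...
    class VectorOutput: ...

    # bound= restricts what callers may request; default= is what type checkers
    # assume when no output type argument is given.
    _T = TypeVar("_T", bound=TextOutput | VectorOutput, default=TextOutput | VectorOutput)

    def collect(
        items: list[TextOutput | VectorOutput],
        output_type: type[_T] | tuple[type[_T], ...],
    ) -> list[_T]:
        results: list[_T] = []
        for item in items:
            # The runtime check backs up the static contract: every returned
            # element really is an instance of the requested type(s).
            assert isinstance(item, output_type)
            results.append(item)  # type: ignore[arg-type]
        return results

    texts = collect([TextOutput(), TextOutput()], TextOutput)  # typed as list[TextOutput]

Calling `collect(..., VectorOutput)` on the same list would fail the assert at runtime instead of silently returning mistyped objects, which is the same guarantee this commit adds to the `LLM` helpers.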
@@ -447,17 +452,16 @@ class LLM:
if sampling_params is None:
sampling_params = self.get_default_sampling_params()
- outputs = self._run_completion(
+ return self._run_completion(
prompts=prompts,
params=sampling_params,
+ output_type=RequestOutput,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
)
- return self.engine_class.validate_outputs(outputs, RequestOutput)
def enqueue(
self,
prompts: PromptType | Sequence[PromptType],
@@ -524,23 +528,43 @@ class LLM:
return request_ids
+ @overload
def wait_for_completion(
self,
*,
use_tqdm: bool | Callable[..., tqdm] = True,
- ) -> list[RequestOutput]:
+ ) -> list[RequestOutput | PoolingRequestOutput]: ...
+ @overload
+ def wait_for_completion(
+ self,
+ output_type: type[_O] | tuple[type[_O], ...],
+ *,
+ use_tqdm: bool | Callable[..., tqdm] = True,
+ ) -> list[_O]: ...
+ def wait_for_completion(
+ self,
+ output_type: type[Any] | tuple[type[Any], ...] | None = None,
+ *,
+ use_tqdm: bool | Callable[..., tqdm] = True,
+ ) -> list[Any]:
"""Wait for all enqueued requests to complete and return results.
This method processes all requests currently in the engine queue
and returns their outputs. Use after enqueue() to get results.
Args:
+ output_type: The expected output type, defaults to RequestOutput.
use_tqdm: If True, shows a tqdm progress bar.
Returns:
- A list of RequestOutput objects for all completed requests.
+ A list of output objects for all completed requests.
"""
- outputs = self._run_engine(use_tqdm=use_tqdm)
- return self.engine_class.validate_outputs(outputs, RequestOutput)
+ if output_type is None:
+ output_type = (RequestOutput, PoolingRequestOutput)
+ return self._run_engine(output_type, use_tqdm=use_tqdm)
def _resolve_mm_lora(
self,
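A hedged usage sketch of the new `enqueue()` / `wait_for_completion()` pairing, based only on the signatures visible in this diff; the model name and the `sampling_params=` keyword on `enqueue()` are assumptions, not taken from the commit:

    from vllm import LLM, SamplingParams
    from vllm.outputs import RequestOutput

    llm = LLM(model="facebook/opt-125m")  # placeholder model choice

    # enqueue() submits requests without blocking; keyword name assumed here.
    llm.enqueue("The capital of France is", sampling_params=SamplingParams(max_tokens=8))

    # Passing an output type narrows the static return type and is enforced by
    # the isinstance check in _run_engine().
    outputs: list[RequestOutput] = llm.wait_for_completion(RequestOutput)
    print(outputs[0].outputs[0].text)

    # With no argument, the overload with the default union applies:
    #   llm.wait_for_completion() -> list[RequestOutput | PoolingRequestOutput]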
@@ -744,13 +768,13 @@ class LLM:
# only runs for one step
# we don't need to use tqdm here
- raw_output = self._render_and_run_requests(
+ output = self._render_and_run_requests(
prompts=(beam.get_prompt() for beam in all_beams),
params=self._params_to_seq(sampling_params, len(all_beams)),
+ output_type=RequestOutput,
lora_requests=[beam.lora_request for beam in all_beams],
use_tqdm=False,
)
- output = self.engine_class.validate_outputs(raw_output, RequestOutput)
for (start, end), instance in zip(
instance_start_and_end, instances_batch
@@ -987,9 +1011,10 @@ class LLM:
if sampling_params is None:
sampling_params = self.get_default_sampling_params()
- outputs = self._run_chat(
+ return self._run_chat(
messages=messages,
params=sampling_params,
+ output_type=RequestOutput,
use_tqdm=use_tqdm,
lora_request=lora_request,
chat_template=chat_template,
@@ -1002,8 +1027,6 @@ class LLM:
mm_processor_kwargs=mm_processor_kwargs,
)
- return self.engine_class.validate_outputs(outputs, RequestOutput)
def encode(
self,
prompts: PromptType | Sequence[PromptType] | DataPrompt,
@@ -1135,19 +1158,16 @@ class LLM:
outputs = self._run_completion(
prompts=prompts_seq,
params=params_seq,
+ output_type=PoolingRequestOutput,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
- model_outputs = self.engine_class.validate_outputs(
- outputs, PoolingRequestOutput
- )
if use_io_processor:
# get the post-processed model outputs
assert self.io_processor is not None
- processed_outputs = self.io_processor.post_process(model_outputs)
+ processed_outputs = self.io_processor.post_process(outputs)
return [
PoolingRequestOutput[Any](
@@ -1160,8 +1180,8 @@ class LLM:
finished=True,
)
]
else:
- return model_outputs
+ return outputs
def embed(
self,
@@ -1353,8 +1373,7 @@ class LLM:
embed_2=encoded_output_2,
)
- items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
- return [ScoringRequestOutput.from_base(item) for item in items]
+ return [ScoringRequestOutput.from_base(item) for item in scores]
def _late_interaction_score(
self,
@@ -1393,7 +1412,7 @@ class LLM:
)
text_2.append(text)
- encoded_output: list[PoolingRequestOutput] = self.encode(
+ encoded_output = self.encode(
text_1 + text_2,
use_tqdm=use_tqdm,
lora_request=lora_request,
@@ -1402,8 +1421,8 @@ class LLM:
tokenization_kwargs=tokenization_kwargs,
)
- encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)]
- encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(text_1) :]
+ encoded_output_1 = encoded_output[0 : len(text_1)]
+ encoded_output_2 = encoded_output[len(text_1) :]
if len(encoded_output_1) == 1:
encoded_output_1 = encoded_output_1 * len(encoded_output_2)
@@ -1434,8 +1453,7 @@ class LLM:
)
)
- items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
- return [ScoringRequestOutput.from_base(item) for item in items]
+ return [ScoringRequestOutput.from_base(item) for item in scores]
def _cross_encoding_score(
self,
@@ -1491,13 +1509,12 @@ class LLM:
outputs = self._run_completion(
prompts=prompts,
params=pooling_params_list,
+ output_type=PoolingRequestOutput,
use_tqdm=use_tqdm,
lora_request=lora_request,
)
- items = self.engine_class.validate_outputs(outputs, PoolingRequestOutput)
- return [ScoringRequestOutput.from_base(item) for item in items]
+ return [ScoringRequestOutput.from_base(item) for item in outputs]
def score(
self,
@@ -1759,6 +1776,7 @@ class LLM:
params: SamplingParams
| PoolingParams
| Sequence[SamplingParams | PoolingParams],
+ output_type: type[_O],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
@@ -1790,6 +1808,7 @@ class LLM:
)
),
params=seq_params,
+ output_type=output_type,
use_tqdm=use_tqdm,
lora_requests=seq_lora_requests,
priorities=seq_priority,
@@ -1802,6 +1821,7 @@ class LLM:
params: SamplingParams
| PoolingParams
| Sequence[SamplingParams | PoolingParams],
+ output_type: type[_O],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
@@ -1848,6 +1868,7 @@ class LLM:
)
),
params=seq_params,
+ output_type=output_type,
lora_requests=seq_lora_requests,
use_tqdm=use_tqdm,
)
@@ -1856,6 +1877,7 @@ class LLM:
self,
prompts: Iterable[ProcessorInputs],
params: Sequence[SamplingParams | PoolingParams],
+ output_type: type[_O],
*,
lora_requests: Sequence[LoRARequest | None] | None = None,
priorities: Sequence[int] | None = None,
@@ -1878,7 +1900,7 @@ class LLM:
priorities=priorities,
)
- return self._run_engine(use_tqdm=use_tqdm)
+ return self._run_engine(output_type, use_tqdm=use_tqdm)
def _render_and_add_requests(
self,
@@ -1932,9 +1954,10 @@ class LLM:
def _run_engine(
self,
+ output_type: type[_O] | tuple[type[_O], ...],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
- ) -> list[RequestOutput | PoolingRequestOutput]:
+ ) -> list[_O]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
@@ -1947,14 +1970,15 @@ class LLM:
)
# Run the engine.
- outputs: list[RequestOutput | PoolingRequestOutput] = []
+ outputs: list[_O] = []
total_in_toks = 0
total_out_toks = 0
while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step()
for output in step_outputs:
+ assert isinstance(output, output_type)
if output.finished:
- outputs.append(output)
+ outputs.append(output)  # type: ignore[arg-type]
if use_tqdm:
if isinstance(output, RequestOutput):
# Calculate tokens only for RequestOutput
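Because `wait_for_completion()` may hand either a single class or a tuple of classes down to this assert, it is worth noting that `isinstance()` accepts both forms natively; a tiny illustration with stand-in classes (not vLLM types):

    class RequestOutputStub: ...
    class PoolingRequestOutputStub: ...

    out = RequestOutputStub()

    # A single type and a tuple of types are both valid second arguments, which
    # is how the (RequestOutput, PoolingRequestOutput) default flows straight
    # into `assert isinstance(output, output_type)`.
    assert isinstance(out, RequestOutputStub)
    assert isinstance(out, (RequestOutputStub, PoolingRequestOutputStub))
    assert not isinstance(out, PoolingRequestOutputStub)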


@@ -199,10 +199,6 @@ class LLMEngine:
self.should_execute_dummy_batch = True
return aggregated_has_unfinished
- @classmethod
- def validate_outputs(cls, outputs, output_type):
- return outputs
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
if not hasattr(self, "_supported_tasks"):
# Cache the result