[Refactor] Implement output type check in LLM (#34794)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung authored 2026-02-20 11:57:55 +08:00, committed by GitHub
parent 76df6072ff
commit ac900c89bb
2 changed files with 58 additions and 38 deletions

View File

@@ -10,7 +10,7 @@ import cloudpickle
 import torch.nn as nn
 from pydantic import ValidationError
 from tqdm.auto import tqdm
-from typing_extensions import TypeVar
+from typing_extensions import TypeVar, overload

 from vllm.beam_search import (
     BeamSearchInstance,
@@ -94,6 +94,11 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)

+_O = TypeVar(
+    "_O",
+    bound=RequestOutput | PoolingRequestOutput,
+    default=RequestOutput | PoolingRequestOutput,
+)
 _P = TypeVar("_P", bound=SamplingParams | PoolingParams | None)
 _R = TypeVar("_R", default=Any)
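For context, `_O` is created with `typing_extensions.TypeVar`, which supports PEP 696 type-parameter defaults: `bound=` caps what callers may substitute, and `default=` is what a type checker assumes when the variable is left unsolved. A minimal sketch of the same pattern; the `Base`/`Box` names are hypothetical, not part of this diff:

from typing import Generic

from typing_extensions import TypeVar

class Base: ...
class Sub(Base): ...

# bound: only Base subclasses may be substituted for _T;
# default: what a bare, unparameterized use resolves to.
_T = TypeVar("_T", bound=Base, default=Base)

class Box(Generic[_T]):
    def __init__(self, item: _T) -> None:
        self.item = item

b: Box = Box(Sub())       # bare Box is treated as Box[Base]
s: Box[Sub] = Box(Sub())  # explicit parameter narrows to Sub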
@@ -447,17 +452,16 @@ class LLM:
         if sampling_params is None:
             sampling_params = self.get_default_sampling_params()

-        outputs = self._run_completion(
+        return self._run_completion(
             prompts=prompts,
             params=sampling_params,
+            output_type=RequestOutput,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
             tokenization_kwargs=tokenization_kwargs,
             priority=priority,
         )
-        return self.engine_class.validate_outputs(outputs, RequestOutput)

     def enqueue(
         self,
         prompts: PromptType | Sequence[PromptType],
@@ -524,23 +528,43 @@ class LLM:
         return request_ids

+    @overload
     def wait_for_completion(
         self,
+        *,
         use_tqdm: bool | Callable[..., tqdm] = True,
-    ) -> list[RequestOutput]:
+    ) -> list[RequestOutput | PoolingRequestOutput]: ...
+
+    @overload
+    def wait_for_completion(
+        self,
+        output_type: type[_O] | tuple[type[_O], ...],
+        *,
+        use_tqdm: bool | Callable[..., tqdm] = True,
+    ) -> list[_O]: ...
+
+    def wait_for_completion(
+        self,
+        output_type: type[Any] | tuple[type[Any], ...] | None = None,
+        *,
+        use_tqdm: bool | Callable[..., tqdm] = True,
+    ) -> list[Any]:
         """Wait for all enqueued requests to complete and return results.

         This method processes all requests currently in the engine queue
         and returns their outputs. Use after enqueue() to get results.

         Args:
+            output_type: The expected output type, defaults to RequestOutput.
             use_tqdm: If True, shows a tqdm progress bar.

         Returns:
-            A list of RequestOutput objects for all completed requests.
+            A list of output objects for all completed requests.
         """
-        outputs = self._run_engine(use_tqdm=use_tqdm)
-        return self.engine_class.validate_outputs(outputs, RequestOutput)
+        if output_type is None:
+            output_type = (RequestOutput, PoolingRequestOutput)
+
+        return self._run_engine(output_type, use_tqdm=use_tqdm)

     def _resolve_mm_lora(
         self,
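With these overloads, a bare `wait_for_completion()` keeps the permissive union result type, while passing an explicit `output_type` narrows the static type and arms the runtime check in `_run_engine`. A hedged usage sketch: the model name is a placeholder, and the exact `enqueue()` argument shape is assumed, since this diff only shows the start of its signature:

from vllm import LLM, SamplingParams
from vllm.outputs import RequestOutput

llm = LLM(model="facebook/opt-125m")  # placeholder model
# Assumed call shape; enqueue()'s full parameter list is not shown above.
llm.enqueue("Hello, my name is", SamplingParams(max_tokens=8))

# Typed as list[RequestOutput]; each finished output is also
# isinstance-checked against RequestOutput at runtime.
outputs = llm.wait_for_completion(RequestOutput)
for out in outputs:
    print(out.outputs[0].text)

# A bare call would instead be typed list[RequestOutput | PoolingRequestOutput]:
# outputs = llm.wait_for_completion()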
@@ -744,13 +768,13 @@ class LLM:
             # only runs for one step
             # we don't need to use tqdm here
-            raw_output = self._render_and_run_requests(
+            output = self._render_and_run_requests(
                 prompts=(beam.get_prompt() for beam in all_beams),
                 params=self._params_to_seq(sampling_params, len(all_beams)),
+                output_type=RequestOutput,
                 lora_requests=[beam.lora_request for beam in all_beams],
                 use_tqdm=False,
             )
-            output = self.engine_class.validate_outputs(raw_output, RequestOutput)

             for (start, end), instance in zip(
                 instance_start_and_end, instances_batch
@@ -987,9 +1011,10 @@ class LLM:
         if sampling_params is None:
             sampling_params = self.get_default_sampling_params()

-        outputs = self._run_chat(
+        return self._run_chat(
             messages=messages,
             params=sampling_params,
+            output_type=RequestOutput,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
             chat_template=chat_template,
@@ -1002,8 +1027,6 @@ class LLM:
             mm_processor_kwargs=mm_processor_kwargs,
         )
-        return self.engine_class.validate_outputs(outputs, RequestOutput)

     def encode(
         self,
         prompts: PromptType | Sequence[PromptType] | DataPrompt,
@@ -1135,19 +1158,16 @@ class LLM:
         outputs = self._run_completion(
             prompts=prompts_seq,
             params=params_seq,
+            output_type=PoolingRequestOutput,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
             tokenization_kwargs=tokenization_kwargs,
         )
-        model_outputs = self.engine_class.validate_outputs(
-            outputs, PoolingRequestOutput
-        )

         if use_io_processor:
             # get the post-processed model outputs
             assert self.io_processor is not None
-            processed_outputs = self.io_processor.post_process(model_outputs)
+            processed_outputs = self.io_processor.post_process(outputs)

             return [
                 PoolingRequestOutput[Any](
@@ -1160,8 +1180,8 @@ class LLM:
                     finished=True,
                 )
             ]
-        else:
-            return model_outputs
+
+        return outputs

     def embed(
         self,
@@ -1353,8 +1373,7 @@ class LLM:
             embed_2=encoded_output_2,
         )

-        items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
-        return [ScoringRequestOutput.from_base(item) for item in items]
+        return [ScoringRequestOutput.from_base(item) for item in scores]

     def _late_interaction_score(
         self,
@@ -1393,7 +1412,7 @@ class LLM:
             )
             text_2.append(text)

-        encoded_output: list[PoolingRequestOutput] = self.encode(
+        encoded_output = self.encode(
             text_1 + text_2,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
@@ -1402,8 +1421,8 @@ class LLM:
             tokenization_kwargs=tokenization_kwargs,
         )

-        encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)]
-        encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(text_1) :]
+        encoded_output_1 = encoded_output[0 : len(text_1)]
+        encoded_output_2 = encoded_output[len(text_1) :]

         if len(encoded_output_1) == 1:
             encoded_output_1 = encoded_output_1 * len(encoded_output_2)
@@ -1434,8 +1453,7 @@ class LLM:
             )
         )

-        items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
-        return [ScoringRequestOutput.from_base(item) for item in items]
+        return [ScoringRequestOutput.from_base(item) for item in scores]

     def _cross_encoding_score(
         self,
@@ -1491,13 +1509,12 @@ class LLM:
         outputs = self._run_completion(
             prompts=prompts,
             params=pooling_params_list,
+            output_type=PoolingRequestOutput,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
         )

-        items = self.engine_class.validate_outputs(outputs, PoolingRequestOutput)
-        return [ScoringRequestOutput.from_base(item) for item in items]
+        return [ScoringRequestOutput.from_base(item) for item in outputs]

     def score(
         self,
@@ -1759,6 +1776,7 @@ class LLM:
         params: SamplingParams
         | PoolingParams
         | Sequence[SamplingParams | PoolingParams],
+        output_type: type[_O],
         *,
         use_tqdm: bool | Callable[..., tqdm] = True,
         lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
@@ -1790,6 +1808,7 @@ class LLM:
                 )
             ),
             params=seq_params,
+            output_type=output_type,
             use_tqdm=use_tqdm,
             lora_requests=seq_lora_requests,
             priorities=seq_priority,
@@ -1802,6 +1821,7 @@ class LLM:
         params: SamplingParams
         | PoolingParams
         | Sequence[SamplingParams | PoolingParams],
+        output_type: type[_O],
         *,
         use_tqdm: bool | Callable[..., tqdm] = True,
         lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
@@ -1848,6 +1868,7 @@ class LLM:
                 )
             ),
             params=seq_params,
+            output_type=output_type,
             lora_requests=seq_lora_requests,
             use_tqdm=use_tqdm,
         )
@@ -1856,6 +1877,7 @@ class LLM:
         self,
         prompts: Iterable[ProcessorInputs],
         params: Sequence[SamplingParams | PoolingParams],
+        output_type: type[_O],
         *,
         lora_requests: Sequence[LoRARequest | None] | None = None,
         priorities: Sequence[int] | None = None,
@@ -1878,7 +1900,7 @@ class LLM:
             priorities=priorities,
         )

-        return self._run_engine(use_tqdm=use_tqdm)
+        return self._run_engine(output_type, use_tqdm=use_tqdm)

     def _render_and_add_requests(
         self,
@@ -1932,9 +1954,10 @@ class LLM:
     def _run_engine(
         self,
+        output_type: type[_O] | tuple[type[_O], ...],
         *,
         use_tqdm: bool | Callable[..., tqdm] = True,
-    ) -> list[RequestOutput | PoolingRequestOutput]:
+    ) -> list[_O]:
         # Initialize tqdm.
         if use_tqdm:
             num_requests = self.llm_engine.get_num_unfinished_requests()
@@ -1947,14 +1970,15 @@ class LLM:
             )

         # Run the engine.
-        outputs: list[RequestOutput | PoolingRequestOutput] = []
+        outputs: list[_O] = []
         total_in_toks = 0
         total_out_toks = 0
         while self.llm_engine.has_unfinished_requests():
             step_outputs = self.llm_engine.step()
             for output in step_outputs:
+                assert isinstance(output, output_type)
                 if output.finished:
-                    outputs.append(output)
+                    outputs.append(output)  # type: ignore[arg-type]
                     if use_tqdm:
                         if isinstance(output, RequestOutput):
                             # Calculate tokens only for RequestOutput
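The added assert works because `isinstance` accepts either a single type or a tuple of types, matching `_run_engine`'s `type[_O] | tuple[type[_O], ...]` parameter; the bare `wait_for_completion()` path passes the two-type tuple. A standalone sketch of the pattern with stand-in classes (illustrative only, not vLLM code):

class RequestOutput: ...
class PoolingRequestOutput: ...

def collect(step_outputs, output_type):
    outputs = []
    for output in step_outputs:
        # Fail fast on an unexpected output type, rather than routing
        # through a no-op validator like the removed validate_outputs().
        assert isinstance(output, output_type)
        outputs.append(output)
    return outputs

collect([RequestOutput()], RequestOutput)  # single type
collect(
    [RequestOutput(), PoolingRequestOutput()],
    (RequestOutput, PoolingRequestOutput),  # a tuple of types also works
)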

View File

@@ -199,10 +199,6 @@ class LLMEngine:
                 self.should_execute_dummy_batch = True
         return aggregated_has_unfinished

-    @classmethod
-    def validate_outputs(cls, outputs, output_type):
-        return outputs
-
     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         if not hasattr(self, "_supported_tasks"):
             # Cache the result