diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 9d1e2912c..f1b32c750 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -10,7 +10,7 @@ import cloudpickle
 import torch.nn as nn
 from pydantic import ValidationError
 from tqdm.auto import tqdm
-from typing_extensions import TypeVar
+from typing_extensions import TypeVar, overload
 
 from vllm.beam_search import (
     BeamSearchInstance,
@@ -94,6 +94,11 @@ if TYPE_CHECKING:
 
 logger = init_logger(__name__)
 
+_O = TypeVar(
+    "_O",
+    bound=RequestOutput | PoolingRequestOutput,
+    default=RequestOutput | PoolingRequestOutput,
+)
 _P = TypeVar("_P", bound=SamplingParams | PoolingParams | None)
 _R = TypeVar("_R", default=Any)
 
@@ -447,17 +452,16 @@ class LLM:
         if sampling_params is None:
             sampling_params = self.get_default_sampling_params()
 
-        outputs = self._run_completion(
+        return self._run_completion(
             prompts=prompts,
             params=sampling_params,
+            output_type=RequestOutput,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
             tokenization_kwargs=tokenization_kwargs,
             priority=priority,
         )
 
-        return self.engine_class.validate_outputs(outputs, RequestOutput)
-
     def enqueue(
         self,
         prompts: PromptType | Sequence[PromptType],
@@ -524,23 +528,43 @@ class LLM:
 
         return request_ids
 
+    @overload
     def wait_for_completion(
         self,
+        *,
         use_tqdm: bool | Callable[..., tqdm] = True,
-    ) -> list[RequestOutput]:
+    ) -> list[RequestOutput | PoolingRequestOutput]: ...
+
+    @overload
+    def wait_for_completion(
+        self,
+        output_type: type[_O] | tuple[type[_O], ...],
+        *,
+        use_tqdm: bool | Callable[..., tqdm] = True,
+    ) -> list[_O]: ...
+
+    def wait_for_completion(
+        self,
+        output_type: type[Any] | tuple[type[Any], ...] | None = None,
+        *,
+        use_tqdm: bool | Callable[..., tqdm] = True,
+    ) -> list[Any]:
         """Wait for all enqueued requests to complete and return results.
 
         This method processes all requests currently in the engine queue
         and returns their outputs. Use after enqueue() to get results.
 
         Args:
+            output_type: The expected output type(s). Defaults to any output type.
             use_tqdm: If True, shows a tqdm progress bar.
 
         Returns:
-            A list of RequestOutput objects for all completed requests.
+            A list of output objects for all completed requests.
         """
-        outputs = self._run_engine(use_tqdm=use_tqdm)
-        return self.engine_class.validate_outputs(outputs, RequestOutput)
+        if output_type is None:
+            output_type = (RequestOutput, PoolingRequestOutput)
+
+        return self._run_engine(output_type, use_tqdm=use_tqdm)
 
     def _resolve_mm_lora(
         self,
@@ -744,13 +768,13 @@ class LLM:
 
             # only runs for one step
             # we don't need to use tqdm here
-            raw_output = self._render_and_run_requests(
+            output = self._render_and_run_requests(
                 prompts=(beam.get_prompt() for beam in all_beams),
                 params=self._params_to_seq(sampling_params, len(all_beams)),
+                output_type=RequestOutput,
                 lora_requests=[beam.lora_request for beam in all_beams],
                 use_tqdm=False,
             )
-            output = self.engine_class.validate_outputs(raw_output, RequestOutput)
 
             for (start, end), instance in zip(
                 instance_start_and_end, instances_batch
@@ -987,9 +1011,10 @@ class LLM:
         if sampling_params is None:
             sampling_params = self.get_default_sampling_params()
 
-        outputs = self._run_chat(
+        return self._run_chat(
             messages=messages,
             params=sampling_params,
+            output_type=RequestOutput,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
             chat_template=chat_template,
@@ -1002,8 +1027,6 @@ class LLM:
             mm_processor_kwargs=mm_processor_kwargs,
         )
 
-        return self.engine_class.validate_outputs(outputs, RequestOutput)
-
     def encode(
         self,
         prompts: PromptType | Sequence[PromptType] | DataPrompt,
@@ -1135,19 +1158,16 @@ class LLM:
         outputs = self._run_completion(
             prompts=prompts_seq,
             params=params_seq,
+            output_type=PoolingRequestOutput,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
             tokenization_kwargs=tokenization_kwargs,
         )
 
-        model_outputs = self.engine_class.validate_outputs(
-            outputs, PoolingRequestOutput
-        )
-
         if use_io_processor:
             # get the post-processed model outputs
             assert self.io_processor is not None
-            processed_outputs = self.io_processor.post_process(model_outputs)
+            processed_outputs = self.io_processor.post_process(outputs)
 
             return [
                 PoolingRequestOutput[Any](
@@ -1160,8 +1180,8 @@ class LLM:
                     finished=True,
                 )
             ]
-        else:
-            return model_outputs
+
+        return outputs
 
     def embed(
         self,
@@ -1353,8 +1373,7 @@ class LLM:
             embed_2=encoded_output_2,
         )
 
-        items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
-        return [ScoringRequestOutput.from_base(item) for item in items]
+        return [ScoringRequestOutput.from_base(item) for item in scores]
 
     def _late_interaction_score(
         self,
@@ -1393,7 +1412,7 @@ class LLM:
             )
             text_2.append(text)
 
-        encoded_output: list[PoolingRequestOutput] = self.encode(
+        encoded_output = self.encode(
             text_1 + text_2,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
@@ -1402,8 +1421,8 @@ class LLM:
             tokenization_kwargs=tokenization_kwargs,
         )
 
-        encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)]
-        encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(text_1) :]
+        encoded_output_1 = encoded_output[0 : len(text_1)]
+        encoded_output_2 = encoded_output[len(text_1) :]
 
         if len(encoded_output_1) == 1:
             encoded_output_1 = encoded_output_1 * len(encoded_output_2)
@@ -1434,8 +1453,7 @@ class LLM:
                 )
             )
 
-        items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
-        return [ScoringRequestOutput.from_base(item) for item in items]
+        return [ScoringRequestOutput.from_base(item) for item in scores]
 
     def _cross_encoding_score(
         self,
@@ -1491,13 +1509,12 @@ class LLM:
         outputs = self._run_completion(
             prompts=prompts,
             params=pooling_params_list,
+            output_type=PoolingRequestOutput,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
         )
 
-        items = self.engine_class.validate_outputs(outputs, PoolingRequestOutput)
-
-        return [ScoringRequestOutput.from_base(item) for item in items]
+        return [ScoringRequestOutput.from_base(item) for item in outputs]
 
     def score(
         self,
@@ -1759,6 +1776,7 @@ class LLM:
         params: SamplingParams
         | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
+        output_type: type[_O],
         *,
         use_tqdm: bool | Callable[..., tqdm] = True,
         lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
@@ -1790,6 +1808,7 @@ class LLM:
                 )
             ),
             params=seq_params,
+            output_type=output_type,
             use_tqdm=use_tqdm,
             lora_requests=seq_lora_requests,
             priorities=seq_priority,
@@ -1802,6 +1821,7 @@ class LLM:
         params: SamplingParams
         | PoolingParams
         | Sequence[SamplingParams | PoolingParams],
+        output_type: type[_O],
         *,
         use_tqdm: bool | Callable[..., tqdm] = True,
         lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
@@ -1848,6 +1868,7 @@ class LLM:
                 )
             ),
             params=seq_params,
+            output_type=output_type,
             lora_requests=seq_lora_requests,
             use_tqdm=use_tqdm,
         )
@@ -1856,6 +1877,7 @@ class LLM:
         self,
         prompts: Iterable[ProcessorInputs],
         params: Sequence[SamplingParams | PoolingParams],
+        output_type: type[_O],
         *,
         lora_requests: Sequence[LoRARequest | None] | None = None,
         priorities: Sequence[int] | None = None,
@@ -1878,7 +1900,7 @@ class LLM:
             priorities=priorities,
         )
 
-        return self._run_engine(use_tqdm=use_tqdm)
+        return self._run_engine(output_type, use_tqdm=use_tqdm)
 
     def _render_and_add_requests(
         self,
@@ -1932,9 +1954,10 @@ class LLM:
 
     def _run_engine(
         self,
+        output_type: type[_O] | tuple[type[_O], ...],
         *,
         use_tqdm: bool | Callable[..., tqdm] = True,
-    ) -> list[RequestOutput | PoolingRequestOutput]:
+    ) -> list[_O]:
         # Initialize tqdm.
         if use_tqdm:
             num_requests = self.llm_engine.get_num_unfinished_requests()
@@ -1947,14 +1970,15 @@ class LLM:
         )
 
         # Run the engine.
-        outputs: list[RequestOutput | PoolingRequestOutput] = []
+        outputs: list[_O] = []
         total_in_toks = 0
         total_out_toks = 0
         while self.llm_engine.has_unfinished_requests():
             step_outputs = self.llm_engine.step()
             for output in step_outputs:
+                assert isinstance(output, output_type)
                 if output.finished:
-                    outputs.append(output)
+                    outputs.append(output)  # type: ignore[arg-type]
                     if use_tqdm:
                         if isinstance(output, RequestOutput):
                             # Calculate tokens only for RequestOutput
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index c4f0442f3..6a8df0dc7 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -199,10 +199,6 @@ class LLMEngine:
             self.should_execute_dummy_batch = True
         return aggregated_has_unfinished
 
-    @classmethod
-    def validate_outputs(cls, outputs, output_type):
-        return outputs
-
     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         if not hasattr(self, "_supported_tasks"):
             # Cache the result
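
Usage note (not part of the patch): with the overloads above, callers either get the union return type or pass an output type to narrow it, replacing the removed engine_class.validate_outputs() step. A minimal sketch, assuming the public vllm imports and a hypothetical model name; the enqueue() parameters beyond prompts are not shown in this diff and are assumed here to mirror generate():

    # Sketch only: illustrates the reworked wait_for_completion() typing.
    from vllm import LLM, RequestOutput, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # hypothetical example model
    params = SamplingParams(max_tokens=16)

    # Queue requests without blocking; the second argument is assumed to be
    # the sampling params, mirroring generate().
    llm.enqueue(["Hello, my name is", "The capital of France is"], params)

    # With no argument the result is typed list[RequestOutput | PoolingRequestOutput];
    # passing RequestOutput selects the narrower overload, so no separate
    # validate_outputs() call is needed.
    outputs = llm.wait_for_completion(RequestOutput)
    for out in outputs:
        print(out.outputs[0].text)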