[Refactor] Implement output type check in LLM (#34794)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -10,7 +10,7 @@ import cloudpickle
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from typing_extensions import TypeVar
|
from typing_extensions import TypeVar, overload
|
||||||
|
|
||||||
from vllm.beam_search import (
|
from vllm.beam_search import (
|
||||||
BeamSearchInstance,
|
BeamSearchInstance,
|
||||||
@@ -94,6 +94,11 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
_O = TypeVar(
|
||||||
|
"_O",
|
||||||
|
bound=RequestOutput | PoolingRequestOutput,
|
||||||
|
default=RequestOutput | PoolingRequestOutput,
|
||||||
|
)
|
||||||
_P = TypeVar("_P", bound=SamplingParams | PoolingParams | None)
|
_P = TypeVar("_P", bound=SamplingParams | PoolingParams | None)
|
||||||
_R = TypeVar("_R", default=Any)
|
_R = TypeVar("_R", default=Any)
|
||||||
|
|
||||||
@@ -447,17 +452,16 @@ class LLM:
|
|||||||
if sampling_params is None:
|
if sampling_params is None:
|
||||||
sampling_params = self.get_default_sampling_params()
|
sampling_params = self.get_default_sampling_params()
|
||||||
|
|
||||||
outputs = self._run_completion(
|
return self._run_completion(
|
||||||
prompts=prompts,
|
prompts=prompts,
|
||||||
params=sampling_params,
|
params=sampling_params,
|
||||||
|
output_type=RequestOutput,
|
||||||
use_tqdm=use_tqdm,
|
use_tqdm=use_tqdm,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
tokenization_kwargs=tokenization_kwargs,
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
priority=priority,
|
priority=priority,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.engine_class.validate_outputs(outputs, RequestOutput)
|
|
||||||
|
|
||||||
def enqueue(
|
def enqueue(
|
||||||
self,
|
self,
|
||||||
prompts: PromptType | Sequence[PromptType],
|
prompts: PromptType | Sequence[PromptType],
|
||||||
@@ -524,23 +528,43 @@ class LLM:
|
|||||||
|
|
||||||
return request_ids
|
return request_ids
|
||||||
|
|
||||||
|
@overload
|
||||||
def wait_for_completion(
|
def wait_for_completion(
|
||||||
self,
|
self,
|
||||||
|
*,
|
||||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||||
) -> list[RequestOutput]:
|
) -> list[RequestOutput | PoolingRequestOutput]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def wait_for_completion(
|
||||||
|
self,
|
||||||
|
output_type: type[_O] | tuple[type[_O], ...],
|
||||||
|
*,
|
||||||
|
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||||
|
) -> list[_O]: ...
|
||||||
|
|
||||||
|
def wait_for_completion(
|
||||||
|
self,
|
||||||
|
output_type: type[Any] | tuple[type[Any], ...] | None = None,
|
||||||
|
*,
|
||||||
|
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||||
|
) -> list[Any]:
|
||||||
"""Wait for all enqueued requests to complete and return results.
|
"""Wait for all enqueued requests to complete and return results.
|
||||||
|
|
||||||
This method processes all requests currently in the engine queue
|
This method processes all requests currently in the engine queue
|
||||||
and returns their outputs. Use after enqueue() to get results.
|
and returns their outputs. Use after enqueue() to get results.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
output_type: The expected output type(s). If omitted, both RequestOutput and PoolingRequestOutput are accepted.
|
||||||
use_tqdm: If True, shows a tqdm progress bar.
|
use_tqdm: If True, shows a tqdm progress bar.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of RequestOutput objects for all completed requests.
|
A list of output objects for all completed requests.
|
||||||
"""
|
"""
|
||||||
outputs = self._run_engine(use_tqdm=use_tqdm)
|
if output_type is None:
|
||||||
return self.engine_class.validate_outputs(outputs, RequestOutput)
|
output_type = (RequestOutput, PoolingRequestOutput)
|
||||||
|
|
||||||
|
return self._run_engine(output_type, use_tqdm=use_tqdm)
|
||||||
|
|
||||||
def _resolve_mm_lora(
|
def _resolve_mm_lora(
|
||||||
self,
|
self,
|
||||||
@@ -744,13 +768,13 @@ class LLM:
|
|||||||
|
|
||||||
# only runs for one step
|
# only runs for one step
|
||||||
# we don't need to use tqdm here
|
# we don't need to use tqdm here
|
||||||
raw_output = self._render_and_run_requests(
|
output = self._render_and_run_requests(
|
||||||
prompts=(beam.get_prompt() for beam in all_beams),
|
prompts=(beam.get_prompt() for beam in all_beams),
|
||||||
params=self._params_to_seq(sampling_params, len(all_beams)),
|
params=self._params_to_seq(sampling_params, len(all_beams)),
|
||||||
|
output_type=RequestOutput,
|
||||||
lora_requests=[beam.lora_request for beam in all_beams],
|
lora_requests=[beam.lora_request for beam in all_beams],
|
||||||
use_tqdm=False,
|
use_tqdm=False,
|
||||||
)
|
)
|
||||||
output = self.engine_class.validate_outputs(raw_output, RequestOutput)
|
|
||||||
|
|
||||||
for (start, end), instance in zip(
|
for (start, end), instance in zip(
|
||||||
instance_start_and_end, instances_batch
|
instance_start_and_end, instances_batch
|
||||||
@@ -987,9 +1011,10 @@ class LLM:
|
|||||||
if sampling_params is None:
|
if sampling_params is None:
|
||||||
sampling_params = self.get_default_sampling_params()
|
sampling_params = self.get_default_sampling_params()
|
||||||
|
|
||||||
outputs = self._run_chat(
|
return self._run_chat(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
params=sampling_params,
|
params=sampling_params,
|
||||||
|
output_type=RequestOutput,
|
||||||
use_tqdm=use_tqdm,
|
use_tqdm=use_tqdm,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
chat_template=chat_template,
|
chat_template=chat_template,
|
||||||
@@ -1002,8 +1027,6 @@ class LLM:
|
|||||||
mm_processor_kwargs=mm_processor_kwargs,
|
mm_processor_kwargs=mm_processor_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.engine_class.validate_outputs(outputs, RequestOutput)
|
|
||||||
|
|
||||||
def encode(
|
def encode(
|
||||||
self,
|
self,
|
||||||
prompts: PromptType | Sequence[PromptType] | DataPrompt,
|
prompts: PromptType | Sequence[PromptType] | DataPrompt,
|
||||||
@@ -1135,19 +1158,16 @@ class LLM:
|
|||||||
outputs = self._run_completion(
|
outputs = self._run_completion(
|
||||||
prompts=prompts_seq,
|
prompts=prompts_seq,
|
||||||
params=params_seq,
|
params=params_seq,
|
||||||
|
output_type=PoolingRequestOutput,
|
||||||
use_tqdm=use_tqdm,
|
use_tqdm=use_tqdm,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
tokenization_kwargs=tokenization_kwargs,
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_outputs = self.engine_class.validate_outputs(
|
|
||||||
outputs, PoolingRequestOutput
|
|
||||||
)
|
|
||||||
|
|
||||||
if use_io_processor:
|
if use_io_processor:
|
||||||
# get the post-processed model outputs
|
# get the post-processed model outputs
|
||||||
assert self.io_processor is not None
|
assert self.io_processor is not None
|
||||||
processed_outputs = self.io_processor.post_process(model_outputs)
|
processed_outputs = self.io_processor.post_process(outputs)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
PoolingRequestOutput[Any](
|
PoolingRequestOutput[Any](
|
||||||
@@ -1160,8 +1180,8 @@ class LLM:
|
|||||||
finished=True,
|
finished=True,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
else:
|
|
||||||
return model_outputs
|
return outputs
|
||||||
|
|
||||||
def embed(
|
def embed(
|
||||||
self,
|
self,
|
||||||
@@ -1353,8 +1373,7 @@ class LLM:
|
|||||||
embed_2=encoded_output_2,
|
embed_2=encoded_output_2,
|
||||||
)
|
)
|
||||||
|
|
||||||
items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
|
return [ScoringRequestOutput.from_base(item) for item in scores]
|
||||||
return [ScoringRequestOutput.from_base(item) for item in items]
|
|
||||||
|
|
||||||
def _late_interaction_score(
|
def _late_interaction_score(
|
||||||
self,
|
self,
|
||||||
@@ -1393,7 +1412,7 @@ class LLM:
|
|||||||
)
|
)
|
||||||
text_2.append(text)
|
text_2.append(text)
|
||||||
|
|
||||||
encoded_output: list[PoolingRequestOutput] = self.encode(
|
encoded_output = self.encode(
|
||||||
text_1 + text_2,
|
text_1 + text_2,
|
||||||
use_tqdm=use_tqdm,
|
use_tqdm=use_tqdm,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
@@ -1402,8 +1421,8 @@ class LLM:
|
|||||||
tokenization_kwargs=tokenization_kwargs,
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)]
|
encoded_output_1 = encoded_output[0 : len(text_1)]
|
||||||
encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(text_1) :]
|
encoded_output_2 = encoded_output[len(text_1) :]
|
||||||
|
|
||||||
if len(encoded_output_1) == 1:
|
if len(encoded_output_1) == 1:
|
||||||
encoded_output_1 = encoded_output_1 * len(encoded_output_2)
|
encoded_output_1 = encoded_output_1 * len(encoded_output_2)
|
||||||
@@ -1434,8 +1453,7 @@ class LLM:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
|
return [ScoringRequestOutput.from_base(item) for item in scores]
|
||||||
return [ScoringRequestOutput.from_base(item) for item in items]
|
|
||||||
|
|
||||||
def _cross_encoding_score(
|
def _cross_encoding_score(
|
||||||
self,
|
self,
|
||||||
@@ -1491,13 +1509,12 @@ class LLM:
|
|||||||
outputs = self._run_completion(
|
outputs = self._run_completion(
|
||||||
prompts=prompts,
|
prompts=prompts,
|
||||||
params=pooling_params_list,
|
params=pooling_params_list,
|
||||||
|
output_type=PoolingRequestOutput,
|
||||||
use_tqdm=use_tqdm,
|
use_tqdm=use_tqdm,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
)
|
)
|
||||||
|
|
||||||
items = self.engine_class.validate_outputs(outputs, PoolingRequestOutput)
|
return [ScoringRequestOutput.from_base(item) for item in outputs]
|
||||||
|
|
||||||
return [ScoringRequestOutput.from_base(item) for item in items]
|
|
||||||
|
|
||||||
def score(
|
def score(
|
||||||
self,
|
self,
|
||||||
@@ -1759,6 +1776,7 @@ class LLM:
|
|||||||
params: SamplingParams
|
params: SamplingParams
|
||||||
| PoolingParams
|
| PoolingParams
|
||||||
| Sequence[SamplingParams | PoolingParams],
|
| Sequence[SamplingParams | PoolingParams],
|
||||||
|
output_type: type[_O],
|
||||||
*,
|
*,
|
||||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||||
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
|
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
|
||||||
@@ -1790,6 +1808,7 @@ class LLM:
|
|||||||
)
|
)
|
||||||
),
|
),
|
||||||
params=seq_params,
|
params=seq_params,
|
||||||
|
output_type=output_type,
|
||||||
use_tqdm=use_tqdm,
|
use_tqdm=use_tqdm,
|
||||||
lora_requests=seq_lora_requests,
|
lora_requests=seq_lora_requests,
|
||||||
priorities=seq_priority,
|
priorities=seq_priority,
|
||||||
@@ -1802,6 +1821,7 @@ class LLM:
|
|||||||
params: SamplingParams
|
params: SamplingParams
|
||||||
| PoolingParams
|
| PoolingParams
|
||||||
| Sequence[SamplingParams | PoolingParams],
|
| Sequence[SamplingParams | PoolingParams],
|
||||||
|
output_type: type[_O],
|
||||||
*,
|
*,
|
||||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||||
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
|
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
|
||||||
@@ -1848,6 +1868,7 @@ class LLM:
|
|||||||
)
|
)
|
||||||
),
|
),
|
||||||
params=seq_params,
|
params=seq_params,
|
||||||
|
output_type=output_type,
|
||||||
lora_requests=seq_lora_requests,
|
lora_requests=seq_lora_requests,
|
||||||
use_tqdm=use_tqdm,
|
use_tqdm=use_tqdm,
|
||||||
)
|
)
|
||||||
@@ -1856,6 +1877,7 @@ class LLM:
|
|||||||
self,
|
self,
|
||||||
prompts: Iterable[ProcessorInputs],
|
prompts: Iterable[ProcessorInputs],
|
||||||
params: Sequence[SamplingParams | PoolingParams],
|
params: Sequence[SamplingParams | PoolingParams],
|
||||||
|
output_type: type[_O],
|
||||||
*,
|
*,
|
||||||
lora_requests: Sequence[LoRARequest | None] | None = None,
|
lora_requests: Sequence[LoRARequest | None] | None = None,
|
||||||
priorities: Sequence[int] | None = None,
|
priorities: Sequence[int] | None = None,
|
||||||
@@ -1878,7 +1900,7 @@ class LLM:
|
|||||||
priorities=priorities,
|
priorities=priorities,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._run_engine(use_tqdm=use_tqdm)
|
return self._run_engine(output_type, use_tqdm=use_tqdm)
|
||||||
|
|
||||||
def _render_and_add_requests(
|
def _render_and_add_requests(
|
||||||
self,
|
self,
|
||||||
@@ -1932,9 +1954,10 @@ class LLM:
|
|||||||
|
|
||||||
def _run_engine(
|
def _run_engine(
|
||||||
self,
|
self,
|
||||||
|
output_type: type[_O] | tuple[type[_O], ...],
|
||||||
*,
|
*,
|
||||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||||
) -> list[RequestOutput | PoolingRequestOutput]:
|
) -> list[_O]:
|
||||||
# Initialize tqdm.
|
# Initialize tqdm.
|
||||||
if use_tqdm:
|
if use_tqdm:
|
||||||
num_requests = self.llm_engine.get_num_unfinished_requests()
|
num_requests = self.llm_engine.get_num_unfinished_requests()
|
||||||
@@ -1947,14 +1970,15 @@ class LLM:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Run the engine.
|
# Run the engine.
|
||||||
outputs: list[RequestOutput | PoolingRequestOutput] = []
|
outputs: list[_O] = []
|
||||||
total_in_toks = 0
|
total_in_toks = 0
|
||||||
total_out_toks = 0
|
total_out_toks = 0
|
||||||
while self.llm_engine.has_unfinished_requests():
|
while self.llm_engine.has_unfinished_requests():
|
||||||
step_outputs = self.llm_engine.step()
|
step_outputs = self.llm_engine.step()
|
||||||
for output in step_outputs:
|
for output in step_outputs:
|
||||||
|
assert isinstance(output, output_type)
|
||||||
if output.finished:
|
if output.finished:
|
||||||
outputs.append(output)
|
outputs.append(output) # type: ignore[arg-type]
|
||||||
if use_tqdm:
|
if use_tqdm:
|
||||||
if isinstance(output, RequestOutput):
|
if isinstance(output, RequestOutput):
|
||||||
# Calculate tokens only for RequestOutput
|
# Calculate tokens only for RequestOutput
|
||||||
|
|||||||
@@ -199,10 +199,6 @@ class LLMEngine:
|
|||||||
self.should_execute_dummy_batch = True
|
self.should_execute_dummy_batch = True
|
||||||
return aggregated_has_unfinished
|
return aggregated_has_unfinished
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def validate_outputs(cls, outputs, output_type):
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||||
if not hasattr(self, "_supported_tasks"):
|
if not hasattr(self, "_supported_tasks"):
|
||||||
# Cache the result
|
# Cache the result
|
||||||
|
|||||||
Reference in New Issue
Block a user