[CI/Build] mypy: Resolve some errors from checking vllm/engine (#9267)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Russell Bryant
2024-10-16 18:55:59 -04:00
committed by GitHub
parent 8345045833
commit 776dbd74f1
20 changed files with 109 additions and 74 deletions

View File

@@ -6,7 +6,7 @@ from functools import partial
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
-from typing import Set, Type, Union, overload
+from typing import Set, Type, Union, cast, overload
import torch
from typing_extensions import TypeVar
@@ -44,7 +44,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
Sequence, SequenceGroup, SequenceGroupMetadata,
-                           SequenceStatus)
+                           SequenceGroupOutput, SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
@@ -188,7 +188,7 @@ class LLMEngine:
raise TypeError(f"Expected output of type {output_type}, "
f"but found type {type(output)}")
-            return output
+            return cast(_O, output)
@classmethod
def validate_outputs(
@@ -1039,6 +1039,7 @@ class LLMEngine:
scheduler_outputs.scheduled_seq_groups)
has_multiple_outputs: bool = len(outputs) > 1
outputs_by_sequence_group: List[List[SequenceGroupOutput]]
if has_multiple_outputs:
assert self.scheduler_config.is_multi_step or \
self.speculative_config
@@ -1084,6 +1085,7 @@ class LLMEngine:
finished_before.append(i)
continue
output: List[SequenceGroupOutput]
if has_multiple_outputs:
output = outputs_by_sequence_group[i]
else:
@@ -1096,7 +1098,7 @@ class LLMEngine:
seq_group, seq_group_meta, is_first_step_output)
else:
seq_group.update_num_computed_tokens(
-                    seq_group_meta.token_chunk_size)
+                    seq_group_meta.token_chunk_size or 0)
if outputs:
for o in outputs:
@@ -1104,13 +1106,13 @@ class LLMEngine:
and seq_group.metrics is not None):
if seq_group.metrics.model_forward_time is not None:
seq_group.metrics.model_forward_time += (
-                            o.model_forward_time)
+                            o.model_forward_time or 0)
else:
seq_group.metrics.model_forward_time = (
o.model_forward_time)
if seq_group.metrics.model_execute_time is not None:
seq_group.metrics.model_execute_time += (
-                            o.model_execute_time)
+                            o.model_execute_time or 0)
else:
seq_group.metrics.model_execute_time = (
o.model_execute_time)
@@ -1236,8 +1238,10 @@ class LLMEngine:
seq_group, seq_group_metadata,
seq_group.state.num_steps == 1)
else:
-                seq_group.update_num_computed_tokens(
-                    seq_group_metadata.token_chunk_size)
+                token_chunk_size = (seq_group_metadata.token_chunk_size
+                                    if seq_group_metadata.token_chunk_size
+                                    is not None else 0)
+                seq_group.update_num_computed_tokens(token_chunk_size)
if seq_group_metadata.do_sample:
assert len(sequence_group_outputs.samples) == 1, (