[CI/Build] mypy: Resolve some errors from checking vllm/engine (#9267)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
@@ -6,7 +6,7 @@ from functools import partial
|
||||
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
|
||||
Iterable, List, Mapping, NamedTuple, Optional)
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Set, Type, Union, overload
|
||||
from typing import Set, Type, Union, cast, overload
|
||||
|
||||
import torch
|
||||
from typing_extensions import TypeVar
|
||||
@@ -44,7 +44,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
|
||||
Sequence, SequenceGroup, SequenceGroupMetadata,
|
||||
SequenceStatus)
|
||||
SequenceGroupOutput, SequenceStatus)
|
||||
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
|
||||
init_tracer)
|
||||
from vllm.transformers_utils.config import try_get_generation_config
|
||||
@@ -188,7 +188,7 @@ class LLMEngine:
|
||||
raise TypeError(f"Expected output of type {output_type}, "
|
||||
f"but found type {type(output)}")
|
||||
|
||||
return output
|
||||
return cast(_O, output)
|
||||
|
||||
@classmethod
|
||||
def validate_outputs(
|
||||
@@ -1039,6 +1039,7 @@ class LLMEngine:
|
||||
scheduler_outputs.scheduled_seq_groups)
|
||||
|
||||
has_multiple_outputs: bool = len(outputs) > 1
|
||||
outputs_by_sequence_group: List[List[SequenceGroupOutput]]
|
||||
if has_multiple_outputs:
|
||||
assert self.scheduler_config.is_multi_step or \
|
||||
self.speculative_config
|
||||
@@ -1084,6 +1085,7 @@ class LLMEngine:
|
||||
finished_before.append(i)
|
||||
continue
|
||||
|
||||
output: List[SequenceGroupOutput]
|
||||
if has_multiple_outputs:
|
||||
output = outputs_by_sequence_group[i]
|
||||
else:
|
||||
@@ -1096,7 +1098,7 @@ class LLMEngine:
|
||||
seq_group, seq_group_meta, is_first_step_output)
|
||||
else:
|
||||
seq_group.update_num_computed_tokens(
|
||||
seq_group_meta.token_chunk_size)
|
||||
seq_group_meta.token_chunk_size or 0)
|
||||
|
||||
if outputs:
|
||||
for o in outputs:
|
||||
@@ -1104,13 +1106,13 @@ class LLMEngine:
|
||||
and seq_group.metrics is not None):
|
||||
if seq_group.metrics.model_forward_time is not None:
|
||||
seq_group.metrics.model_forward_time += (
|
||||
o.model_forward_time)
|
||||
o.model_forward_time or 0)
|
||||
else:
|
||||
seq_group.metrics.model_forward_time = (
|
||||
o.model_forward_time)
|
||||
if seq_group.metrics.model_execute_time is not None:
|
||||
seq_group.metrics.model_execute_time += (
|
||||
o.model_execute_time)
|
||||
o.model_execute_time or 0)
|
||||
else:
|
||||
seq_group.metrics.model_execute_time = (
|
||||
o.model_execute_time)
|
||||
@@ -1236,8 +1238,10 @@ class LLMEngine:
|
||||
seq_group, seq_group_metadata,
|
||||
seq_group.state.num_steps == 1)
|
||||
else:
|
||||
seq_group.update_num_computed_tokens(
|
||||
seq_group_metadata.token_chunk_size)
|
||||
token_chunk_size = (seq_group_metadata.token_chunk_size
|
||||
if seq_group_metadata.token_chunk_size
|
||||
is not None else 0)
|
||||
seq_group.update_num_computed_tokens(token_chunk_size)
|
||||
|
||||
if seq_group_metadata.do_sample:
|
||||
assert len(sequence_group_outputs.samples) == 1, (
|
||||
|
||||
Reference in New Issue
Block a user