Enable mypy checking on V1 code (#11105)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
Mark McLoughlin
2024-12-14 17:54:04 +00:00
committed by GitHub
parent 93abf23a64
commit 6d917d0eeb
21 changed files with 160 additions and 121 deletions

View File

@@ -5,7 +5,7 @@ import threading
import time
from dataclasses import dataclass
from multiprocessing.process import BaseProcess
from typing import List, Tuple, Type, Union
from typing import List, Tuple, Type
import zmq
import zmq.asyncio
@@ -20,7 +20,7 @@ from vllm.usage.usage_lib import UsageContext
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreProfile, EngineCoreRequest,
EngineCoreRequestType)
EngineCoreRequestType, EngineCoreRequestUnion)
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.request import Request, RequestStatus
@@ -97,8 +97,10 @@ class EngineCore:
# Note that the cache here is mirrored with the client side of the
# MM mapper, so anything that has a hash must have a HIT cache
# entry here as well.
request.mm_inputs = self.mm_input_mapper_server.process_inputs(
request.mm_inputs, request.mm_hashes)
assert request.mm_inputs is not None
request.mm_inputs, request.mm_hashes = (
self.mm_input_mapper_server.process_inputs(
request.mm_inputs, request.mm_hashes))
req = Request.from_engine_core_request(request)
@@ -128,7 +130,7 @@ class EngineCore:
def shutdown(self):
self.model_executor.shutdown()
def profile(self, is_start=True):
def profile(self, is_start: bool = True):
self.model_executor.profile(is_start)
@@ -161,8 +163,8 @@ class EngineCoreProc(EngineCore):
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self.input_queue = queue.Queue()
self.output_queue = queue.Queue()
self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
self.output_queue: queue.Queue[List[EngineCoreOutput]] = queue.Queue()
threading.Thread(target=self.process_input_socket,
args=(input_path, ),
daemon=True).start()
@@ -318,9 +320,7 @@ class EngineCoreProc(EngineCore):
self._last_logging_time = now
def _handle_client_request(
self, request: Union[EngineCoreRequest, EngineCoreProfile,
List[str]]) -> None:
def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
"""Handle EngineCoreRequest or EngineCoreABORT from Client."""
if isinstance(request, EngineCoreRequest):