Enable mypy checking on V1 code (#11105)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
Mark McLoughlin
2024-12-14 17:54:04 +00:00
committed by GitHub
parent 93abf23a64
commit 6d917d0eeb
21 changed files with 160 additions and 121 deletions

View File

@@ -7,7 +7,7 @@ import time
from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing.process import BaseProcess
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import zmq
@@ -21,6 +21,7 @@ from vllm.executor.multiproc_worker_utils import (
from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_open_port,
get_open_zmq_ipc_path)
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import make_zmq_socket
from vllm.worker.worker_base import WorkerWrapperBase
@@ -31,7 +32,7 @@ POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
class MultiprocExecutor:
class MultiprocExecutor(Executor):
def __init__(self, vllm_config: VllmConfig) -> None:
# Call self.shutdown at exit to clean up
@@ -103,7 +104,7 @@ class MultiprocExecutor:
method: str,
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> []:
kwargs: Optional[Dict] = None) -> List[Any]:
"""
Execute an RPC call on workers.
@@ -125,7 +126,7 @@ class MultiprocExecutor:
responses = [None] * self.world_size
for w in self.workers:
dequeue_timeout = timeout - (time.monotonic() - start_time()
dequeue_timeout = timeout - (time.monotonic() - start_time
) if timeout is not None else None
status, result = w.worker_response_mq.dequeue(
timeout=dequeue_timeout)
@@ -153,7 +154,7 @@ class MultiprocExecutor:
args=(scheduler_output, ))[0]
return model_output
def profile(self, is_start=True):
def profile(self, is_start: bool = True):
self.collective_rpc("profile", args=(is_start, ))
return
@@ -185,7 +186,6 @@ class MultiprocExecutor:
p.kill()
self._cleanup_sockets()
self.workers = None
def _cleanup_sockets(self):
for w in self.workers:
@@ -200,7 +200,8 @@ class MultiprocExecutor:
# again
atexit.unregister(self.shutdown)
"""Properly shut down the executor and its workers"""
if (hasattr(self, 'workers') and self.workers is not None):
if getattr(self, 'shutting_down', False):
self.shutting_down = True
for w in self.workers: #TODO: not sure if needed
w.worker_response_mq = None
self._ensure_worker_termination()