Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -9,6 +9,7 @@ import threading
|
||||
import time
|
||||
import traceback
|
||||
import weakref
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
@@ -17,7 +18,7 @@ from multiprocessing.connection import Connection
|
||||
from multiprocessing.process import BaseProcess
|
||||
from multiprocessing.synchronize import Lock as LockType
|
||||
from threading import Thread
|
||||
from typing import Any, Callable, Optional, Union, cast
|
||||
from typing import Any, cast
|
||||
|
||||
import cloudpickle
|
||||
import torch
|
||||
@@ -59,8 +60,8 @@ class MultiprocExecutor(Executor):
|
||||
self._finalizer = weakref.finalize(self, self.shutdown)
|
||||
self.is_failed = False
|
||||
self.shutdown_event = threading.Event()
|
||||
self.failure_callback: Optional[FailureCallback] = None
|
||||
self.io_thread_pool: Optional[ThreadPoolExecutor] = None
|
||||
self.failure_callback: FailureCallback | None = None
|
||||
self.io_thread_pool: ThreadPoolExecutor | None = None
|
||||
|
||||
self.world_size = self.parallel_config.world_size
|
||||
tensor_parallel_size = self.parallel_config.tensor_parallel_size
|
||||
@@ -179,7 +180,7 @@ class MultiprocExecutor(Executor):
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
non_block: bool = False,
|
||||
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
|
||||
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
|
||||
if not self.has_connector:
|
||||
# get output only from a single worker (output_rank)
|
||||
(output,) = self.collective_rpc(
|
||||
@@ -207,7 +208,7 @@ class MultiprocExecutor(Executor):
|
||||
def execute_dummy_batch(self) -> None:
|
||||
self.collective_rpc("execute_dummy_batch", unique_reply_rank=self.output_rank)
|
||||
|
||||
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
|
||||
def take_draft_token_ids(self) -> DraftTokenIds | None:
|
||||
# OPTIMIZATION: Get output only from a single worker (output_rank)
|
||||
outputs = self.collective_rpc(
|
||||
"take_draft_token_ids", unique_reply_rank=self.output_rank
|
||||
@@ -216,12 +217,12 @@ class MultiprocExecutor(Executor):
|
||||
|
||||
def collective_rpc(
|
||||
self,
|
||||
method: Union[str, Callable],
|
||||
timeout: Optional[float] = None,
|
||||
method: str | Callable,
|
||||
timeout: float | None = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict] = None,
|
||||
kwargs: dict | None = None,
|
||||
non_block: bool = False,
|
||||
unique_reply_rank: Optional[int] = None,
|
||||
unique_reply_rank: int | None = None,
|
||||
) -> list[Any]:
|
||||
if self.is_failed:
|
||||
raise RuntimeError("Executor failed.")
|
||||
@@ -252,8 +253,8 @@ class MultiprocExecutor(Executor):
|
||||
|
||||
def get_response(
|
||||
w: WorkerProcHandle,
|
||||
dequeue_timeout: Optional[float] = None,
|
||||
cancel_event: Optional[threading.Event] = None,
|
||||
dequeue_timeout: float | None = None,
|
||||
cancel_event: threading.Event | None = None,
|
||||
):
|
||||
status, result = w.worker_response_mq.dequeue(
|
||||
timeout=dequeue_timeout, cancel=cancel_event
|
||||
@@ -370,7 +371,7 @@ class UnreadyWorkerProcHandle:
|
||||
proc: BaseProcess
|
||||
rank: int
|
||||
ready_pipe: Connection
|
||||
death_writer: Optional[Connection] = None
|
||||
death_writer: Connection | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -378,7 +379,7 @@ class WorkerProcHandle:
|
||||
proc: BaseProcess
|
||||
rank: int
|
||||
worker_response_mq: MessageQueue # The worker process writes to this MQ
|
||||
death_writer: Optional[Connection] = None
|
||||
death_writer: Connection | None = None
|
||||
|
||||
@classmethod
|
||||
def from_unready_handle(
|
||||
@@ -505,7 +506,7 @@ class WorkerProc:
|
||||
)
|
||||
|
||||
pipes = {handle.ready_pipe: handle for handle in unready_proc_handles}
|
||||
ready_proc_handles: list[Optional[WorkerProcHandle]] = [None] * len(
|
||||
ready_proc_handles: list[WorkerProcHandle | None] = [None] * len(
|
||||
unready_proc_handles
|
||||
)
|
||||
while pipes:
|
||||
@@ -674,7 +675,7 @@ class WorkerProc:
|
||||
output = self.async_output_queue.get()
|
||||
self.enqueue_output(output)
|
||||
|
||||
def worker_busy_loop(self, cancel: Optional[threading.Event] = None):
|
||||
def worker_busy_loop(self, cancel: threading.Event | None = None):
|
||||
"""Main busy loop for Multiprocessing Workers"""
|
||||
while True:
|
||||
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue(
|
||||
|
||||
Reference in New Issue
Block a user