Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-12 17:51:31 +01:00
committed by GitHub
parent 9bb38130cb
commit 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions

View File

@@ -9,6 +9,7 @@ import threading
import time
import traceback
import weakref
from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum, auto
@@ -17,7 +18,7 @@ from multiprocessing.connection import Connection
from multiprocessing.process import BaseProcess
from multiprocessing.synchronize import Lock as LockType
from threading import Thread
from typing import Any, Callable, Optional, Union, cast
from typing import Any, cast
import cloudpickle
import torch
@@ -59,8 +60,8 @@ class MultiprocExecutor(Executor):
self._finalizer = weakref.finalize(self, self.shutdown)
self.is_failed = False
self.shutdown_event = threading.Event()
self.failure_callback: Optional[FailureCallback] = None
self.io_thread_pool: Optional[ThreadPoolExecutor] = None
self.failure_callback: FailureCallback | None = None
self.io_thread_pool: ThreadPoolExecutor | None = None
self.world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
@@ -179,7 +180,7 @@ class MultiprocExecutor(Executor):
self,
scheduler_output: SchedulerOutput,
non_block: bool = False,
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
if not self.has_connector:
# get output only from a single worker (output_rank)
(output,) = self.collective_rpc(
@@ -207,7 +208,7 @@ class MultiprocExecutor(Executor):
def execute_dummy_batch(self) -> None:
self.collective_rpc("execute_dummy_batch", unique_reply_rank=self.output_rank)
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
def take_draft_token_ids(self) -> DraftTokenIds | None:
# OPTIMIZATION: Get output only from a single worker (output_rank)
outputs = self.collective_rpc(
"take_draft_token_ids", unique_reply_rank=self.output_rank
@@ -216,12 +217,12 @@ class MultiprocExecutor(Executor):
def collective_rpc(
self,
method: Union[str, Callable],
timeout: Optional[float] = None,
method: str | Callable,
timeout: float | None = None,
args: tuple = (),
kwargs: Optional[dict] = None,
kwargs: dict | None = None,
non_block: bool = False,
unique_reply_rank: Optional[int] = None,
unique_reply_rank: int | None = None,
) -> list[Any]:
if self.is_failed:
raise RuntimeError("Executor failed.")
@@ -252,8 +253,8 @@ class MultiprocExecutor(Executor):
def get_response(
w: WorkerProcHandle,
dequeue_timeout: Optional[float] = None,
cancel_event: Optional[threading.Event] = None,
dequeue_timeout: float | None = None,
cancel_event: threading.Event | None = None,
):
status, result = w.worker_response_mq.dequeue(
timeout=dequeue_timeout, cancel=cancel_event
@@ -370,7 +371,7 @@ class UnreadyWorkerProcHandle:
proc: BaseProcess
rank: int
ready_pipe: Connection
death_writer: Optional[Connection] = None
death_writer: Connection | None = None
@dataclass
@@ -378,7 +379,7 @@ class WorkerProcHandle:
proc: BaseProcess
rank: int
worker_response_mq: MessageQueue # The worker process writes to this MQ
death_writer: Optional[Connection] = None
death_writer: Connection | None = None
@classmethod
def from_unready_handle(
@@ -505,7 +506,7 @@ class WorkerProc:
)
pipes = {handle.ready_pipe: handle for handle in unready_proc_handles}
ready_proc_handles: list[Optional[WorkerProcHandle]] = [None] * len(
ready_proc_handles: list[WorkerProcHandle | None] = [None] * len(
unready_proc_handles
)
while pipes:
@@ -674,7 +675,7 @@ class WorkerProc:
output = self.async_output_queue.get()
self.enqueue_output(output)
def worker_busy_loop(self, cancel: Optional[threading.Event] = None):
def worker_busy_loop(self, cancel: threading.Event | None = None):
"""Main busy loop for Multiprocessing Workers"""
while True:
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue(