Update deprecated type hinting in vllm/device_allocator and vllm/distributed (#18126)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-05-14 12:07:57 +01:00
Committed by: GitHub
Parent: 9b5b39b650
Commit: dc372b9c8a
23 changed files with 105 additions and 105 deletions
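
For context: Python 3.9 (PEP 585) deprecated the typing aliases Dict, List and Tuple in favor of the builtin generics dict, list and tuple, and Sequence now comes from collections.abc; Optional and Union are still imported from typing. Every hunk below follows that pattern. A minimal before/after sketch, with illustrative names not taken from the diff:

from typing import Optional

# Before (deprecated aliases):
#     from typing import Dict, List, Tuple
#     def summarize(pairs: List[Tuple[str, int]]) -> Dict[str, int]: ...

# After (builtin generics; Optional still comes from typing):
def summarize(pairs: list[tuple[str, int]]) -> dict[str, int]:
    """Sum the values for each key in a list of (key, value) pairs."""
    totals: dict[str, int] = {}
    for key, value in pairs:
        totals[key] = totals.get(key, 0) + value
    return totals

def first_or_none(items: list[int]) -> Optional[int]:
    return items[0] if items else None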

View File

@@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
-from typing import Optional, Tuple
+from typing import Optional
import torch
import torch.distributed as dist
@@ -160,7 +160,7 @@ class DeviceCommunicatorBase:
def dispatch(
self, hidden_states: torch.Tensor,
-router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import os
-from typing import List, Optional
+from typing import Optional
import torch
from torch.distributed import ProcessGroup
@@ -126,7 +126,7 @@ class _CPUSHMDistributed:
def gather(self,
input: torch.Tensor,
-gather_list: Optional[List[torch.Tensor]],
+gather_list: Optional[list[torch.Tensor]],
dst: int = -1,
group: Optional[ProcessGroup] = None) -> None:
# Note: different from the torch gather, here we use local dst rank.

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
-from typing import Optional, Tuple
+from typing import Optional
import torch
from torch.distributed import ProcessGroup
@@ -154,7 +154,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
def dispatch(
self, hidden_states: torch.Tensor,
-router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_impl is not None
hidden_states, router_logits = self.all2all_impl.dispatch(
hidden_states, router_logits)

View File

@@ -6,7 +6,7 @@ convenient for use when we just need to call a few functions.
import ctypes
from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
# this line makes it possible to directly load `libcudart.so` using `ctypes`
import torch # noqa
@@ -32,7 +32,7 @@ class cudaIpcMemHandle_t(ctypes.Structure):
class Function:
name: str
restype: Any
-argtypes: List[Any]
+argtypes: list[Any]
def find_loaded_library(lib_name) -> Optional[str]:
@@ -97,11 +97,11 @@ class CudaRTLibrary:
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
-path_to_library_cache: Dict[str, Any] = {}
+path_to_library_cache: dict[str, Any] = {}
# class attribute to store the mapping from library path
# to the corresponding dictionary
-path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
+path_to_dict_mapping: dict[str, dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
if so_file is None:

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from contextlib import contextmanager
-from typing import List, Optional, Union
+from typing import Optional, Union
import torch
import torch.distributed as dist
@@ -276,7 +276,7 @@ class CustomAllreduce:
@staticmethod
def create_shared_buffer(size_in_bytes: int,
group: Optional[ProcessGroup] = None,
-uncached: Optional[bool] = False) -> List[int]:
+uncached: Optional[bool] = False) -> list[int]:
pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
world_size = dist.get_world_size(group=group)
@@ -284,7 +284,7 @@ class CustomAllreduce:
handles = [None] * world_size
dist.all_gather_object(handles, handle, group=group)
-pointers: List[int] = []
+pointers: list[int] = []
for i, h in enumerate(handles):
if i == rank:
pointers.append(pointer) # type: ignore
@@ -293,7 +293,7 @@ class CustomAllreduce:
return pointers
@staticmethod
-def free_shared_buffer(pointers: List[int],
+def free_shared_buffer(pointers: list[int],
group: Optional[ProcessGroup] = None,
rank: Optional[int] = 0) -> None:
if rank is None:
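
The create_shared_buffer/free_shared_buffer changes above sit in code that exchanges per-rank IPC handles with torch.distributed.all_gather_object before turning them into device pointers. A minimal sketch of just that exchange step, assuming an already-initialized process group; exchange_handles is a hypothetical helper, and the vLLM custom ops that allocate and open the buffers are omitted:

from typing import Optional

import torch.distributed as dist
from torch.distributed import ProcessGroup

def exchange_handles(local_handle: bytes,
                     group: Optional[ProcessGroup] = None) -> list[bytes]:
    """Gather every rank's opaque buffer handle so each rank can later map
    its peers' buffers; opening the handles is left out of this sketch."""
    world_size = dist.get_world_size(group=group)
    handles: list[Optional[bytes]] = [None] * world_size
    dist.all_gather_object(handles, local_handle, group=group)
    return [h for h in handles if h is not None]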

View File

@@ -7,8 +7,9 @@ import pickle
import subprocess
import sys
import tempfile
+from collections.abc import Sequence
from itertools import product
-from typing import Dict, List, Optional, Sequence
+from typing import Optional
import torch.distributed as dist
import torch.multiprocessing as mp
@@ -149,7 +150,7 @@ def can_actually_p2p(
p_src.join()
p_tgt.join()
assert p_src.exitcode == 0 and p_tgt.exitcode == 0
-result: List[bool] = []
+result: list[bool] = []
for src, tgt in zip(batch_src, batch_tgt):
a = result_queue.get()
b = result_queue.get()
@@ -175,7 +176,7 @@ def can_actually_p2p(
# e.g. used by different vllm engines. The device id in the cache file is a
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
# of visible devices in the vllm engine.
-_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
+_gpu_p2p_access_cache: Optional[dict[str, bool]] = None
def gpu_p2p_access_check(src: int, tgt: int) -> bool:
@@ -204,7 +205,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
# only the local master process (with local_rank == 0) can
# enter this block to calculate the cache
logger.info("generating GPU P2P access cache in %s", path)
-cache: Dict[str, bool] = {}
+cache: dict[str, bool] = {}
ids = list(range(num_dev))
# batch of all pairs of GPUs
batch_src, batch_tgt = zip(*list(product(ids, ids)))
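
The two hunks above keep a dict[str, bool] cache with one entry per ordered pair of local device ids, built from itertools.product over all visible GPUs. A small sketch of that cache shape, using a hypothetical build_p2p_cache helper and an illustrative key format:

from collections.abc import Callable
from itertools import product

def build_p2p_cache(num_dev: int,
                    can_p2p: Callable[[int, int], bool]) -> dict[str, bool]:
    """Record whether each ordered (src, tgt) pair of local GPU ids can use
    peer-to-peer access; the "src->tgt" key format here is illustrative."""
    cache: dict[str, bool] = {}
    for src, tgt in product(range(num_dev), range(num_dev)):
        cache[f"{src}->{tgt}"] = bool(can_p2p(src, tgt))
    return cache

# Example: a 2-GPU machine where only same-device "transfers" succeed.
print(build_p2p_cache(2, lambda a, b: a == b))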

View File

@@ -24,7 +24,7 @@
import ctypes
import platform
from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
import torch
from torch.distributed import ReduceOp
@@ -121,7 +121,7 @@ class ncclRedOpTypeEnum:
class Function:
name: str
restype: Any
-argtypes: List[Any]
+argtypes: list[Any]
class NCCLLibrary:
@@ -210,11 +210,11 @@ class NCCLLibrary:
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
-path_to_library_cache: Dict[str, Any] = {}
+path_to_library_cache: dict[str, Any] = {}
# class attribute to store the mapping from library path
# to the corresponding dictionary
-path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
+path_to_dict_mapping: dict[str, dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
@@ -238,7 +238,7 @@ class NCCLLibrary:
raise e
if so_file not in NCCLLibrary.path_to_dict_mapping:
-_funcs: Dict[str, Any] = {}
+_funcs: dict[str, Any] = {}
for func in NCCLLibrary.exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype
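
Both the cudart and NCCL wrappers describe each exported C function with a small Function dataclass and then set restype/argtypes on the ctypes handle, caching the bound functions per library path. A self-contained sketch of that bind-and-cache pattern, assuming a POSIX system and using libc's strlen purely as a stand-in for the CUDA/NCCL symbols:

import ctypes
from dataclasses import dataclass
from typing import Any

@dataclass
class Function:
    name: str
    restype: Any
    argtypes: list[Any]

# Illustrative table of functions to bind (only strlen here).
exported_functions = [Function("strlen", ctypes.c_size_t, [ctypes.c_char_p])]

lib = ctypes.CDLL(None)  # resolve symbols already loaded into this process
_funcs: dict[str, Any] = {}
for func in exported_functions:
    f = getattr(lib, func.name)
    f.restype = func.restype
    f.argtypes = func.argtypes
    _funcs[func.name] = f

assert _funcs["strlen"](b"vllm") == 4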

View File

@@ -8,7 +8,7 @@ from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
from threading import Event
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
from unittest.mock import patch
import torch
@@ -173,9 +173,9 @@ class ShmRingBuffer:
@dataclass
class Handle:
-local_reader_ranks: List[int] = field(default_factory=list)
+local_reader_ranks: list[int] = field(default_factory=list)
-buffer_handle: Optional[Tuple[int, int, int, str]] = None
+buffer_handle: Optional[tuple[int, int, int, str]] = None
local_subscribe_addr: Optional[str] = None
remote_subscribe_addr: Optional[str] = None
remote_addr_ipv6: bool = False
@@ -187,7 +187,7 @@ class MessageQueue:
self,
n_reader, # number of all readers
n_local_reader, # number of local readers through shared memory
-local_reader_ranks: Optional[List[int]] = None,
+local_reader_ranks: Optional[list[int]] = None,
max_chunk_bytes: int = 1024 * 1024 * 10,
max_chunks: int = 10,
connect_ip: Optional[str] = None,