Update deprecated type hinting in vllm/device_allocator and vllm/distributed (#18126)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -160,7 +160,7 @@ class DeviceCommunicatorBase:
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Dispatch the hidden states and router logits to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
@@ -126,7 +126,7 @@ class _CPUSHMDistributed:
|
||||
|
||||
def gather(self,
|
||||
input: torch.Tensor,
|
||||
gather_list: Optional[List[torch.Tensor]],
|
||||
gather_list: Optional[list[torch.Tensor]],
|
||||
dst: int = -1,
|
||||
group: Optional[ProcessGroup] = None) -> None:
|
||||
# Note: different from the torch gather, here we use local dst rank.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
@@ -154,7 +154,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.all2all_impl is not None
|
||||
hidden_states, router_logits = self.all2all_impl.dispatch(
|
||||
hidden_states, router_logits)
|
||||
|
||||
@@ -6,7 +6,7 @@ convenient for use when we just need to call a few functions.
|
||||
|
||||
import ctypes
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
# this line makes it possible to directly load `libcudart.so` using `ctypes`
|
||||
import torch # noqa
|
||||
@@ -32,7 +32,7 @@ class cudaIpcMemHandle_t(ctypes.Structure):
|
||||
class Function:
|
||||
name: str
|
||||
restype: Any
|
||||
argtypes: List[Any]
|
||||
argtypes: list[Any]
|
||||
|
||||
|
||||
def find_loaded_library(lib_name) -> Optional[str]:
|
||||
@@ -97,11 +97,11 @@ class CudaRTLibrary:
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
# to avoid loading the same library multiple times
|
||||
path_to_library_cache: Dict[str, Any] = {}
|
||||
path_to_library_cache: dict[str, Any] = {}
|
||||
|
||||
# class attribute to store the mapping from library path
|
||||
# to the corresponding dictionary
|
||||
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
|
||||
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
if so_file is None:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -276,7 +276,7 @@ class CustomAllreduce:
|
||||
@staticmethod
|
||||
def create_shared_buffer(size_in_bytes: int,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
uncached: Optional[bool] = False) -> List[int]:
|
||||
uncached: Optional[bool] = False) -> list[int]:
|
||||
pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
|
||||
|
||||
world_size = dist.get_world_size(group=group)
|
||||
@@ -284,7 +284,7 @@ class CustomAllreduce:
|
||||
handles = [None] * world_size
|
||||
dist.all_gather_object(handles, handle, group=group)
|
||||
|
||||
pointers: List[int] = []
|
||||
pointers: list[int] = []
|
||||
for i, h in enumerate(handles):
|
||||
if i == rank:
|
||||
pointers.append(pointer) # type: ignore
|
||||
@@ -293,7 +293,7 @@ class CustomAllreduce:
|
||||
return pointers
|
||||
|
||||
@staticmethod
|
||||
def free_shared_buffer(pointers: List[int],
|
||||
def free_shared_buffer(pointers: list[int],
|
||||
group: Optional[ProcessGroup] = None,
|
||||
rank: Optional[int] = 0) -> None:
|
||||
if rank is None:
|
||||
|
||||
@@ -7,8 +7,9 @@ import pickle
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from collections.abc import Sequence
|
||||
from itertools import product
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
from typing import Optional
|
||||
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
@@ -149,7 +150,7 @@ def can_actually_p2p(
|
||||
p_src.join()
|
||||
p_tgt.join()
|
||||
assert p_src.exitcode == 0 and p_tgt.exitcode == 0
|
||||
result: List[bool] = []
|
||||
result: list[bool] = []
|
||||
for src, tgt in zip(batch_src, batch_tgt):
|
||||
a = result_queue.get()
|
||||
b = result_queue.get()
|
||||
@@ -175,7 +176,7 @@ def can_actually_p2p(
|
||||
# e.g. used by different vllm engines. The device id in the cache file is a
|
||||
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
|
||||
# of visible devices in the vllm engine.
|
||||
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
|
||||
_gpu_p2p_access_cache: Optional[dict[str, bool]] = None
|
||||
|
||||
|
||||
def gpu_p2p_access_check(src: int, tgt: int) -> bool:
|
||||
@@ -204,7 +205,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
|
||||
# only the local master process (with local_rank == 0) can
|
||||
# enter this block to calculate the cache
|
||||
logger.info("generating GPU P2P access cache in %s", path)
|
||||
cache: Dict[str, bool] = {}
|
||||
cache: dict[str, bool] = {}
|
||||
ids = list(range(num_dev))
|
||||
# batch of all pairs of GPUs
|
||||
batch_src, batch_tgt = zip(*list(product(ids, ids)))
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
import ctypes
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
@@ -121,7 +121,7 @@ class ncclRedOpTypeEnum:
|
||||
class Function:
|
||||
name: str
|
||||
restype: Any
|
||||
argtypes: List[Any]
|
||||
argtypes: list[Any]
|
||||
|
||||
|
||||
class NCCLLibrary:
|
||||
@@ -210,11 +210,11 @@ class NCCLLibrary:
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
# to avoid loading the same library multiple times
|
||||
path_to_library_cache: Dict[str, Any] = {}
|
||||
path_to_library_cache: dict[str, Any] = {}
|
||||
|
||||
# class attribute to store the mapping from library path
|
||||
# to the corresponding dictionary
|
||||
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
|
||||
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
|
||||
@@ -238,7 +238,7 @@ class NCCLLibrary:
|
||||
raise e
|
||||
|
||||
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
||||
_funcs: Dict[str, Any] = {}
|
||||
_funcs: dict[str, Any] = {}
|
||||
for func in NCCLLibrary.exported_functions:
|
||||
f = getattr(self.lib, func.name)
|
||||
f.restype = func.restype
|
||||
|
||||
@@ -8,7 +8,7 @@ from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from multiprocessing import shared_memory
|
||||
from threading import Event
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
@@ -173,9 +173,9 @@ class ShmRingBuffer:
|
||||
|
||||
@dataclass
|
||||
class Handle:
|
||||
local_reader_ranks: List[int] = field(default_factory=list)
|
||||
local_reader_ranks: list[int] = field(default_factory=list)
|
||||
|
||||
buffer_handle: Optional[Tuple[int, int, int, str]] = None
|
||||
buffer_handle: Optional[tuple[int, int, int, str]] = None
|
||||
local_subscribe_addr: Optional[str] = None
|
||||
remote_subscribe_addr: Optional[str] = None
|
||||
remote_addr_ipv6: bool = False
|
||||
@@ -187,7 +187,7 @@ class MessageQueue:
|
||||
self,
|
||||
n_reader, # number of all readers
|
||||
n_local_reader, # number of local readers through shared memory
|
||||
local_reader_ranks: Optional[List[int]] = None,
|
||||
local_reader_ranks: Optional[list[int]] = None,
|
||||
max_chunk_bytes: int = 1024 * 1024 * 10,
|
||||
max_chunks: int = 10,
|
||||
connect_ip: Optional[str] = None,
|
||||
|
||||
Reference in New Issue
Block a user