Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -368,7 +368,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
return handle
|
||||
|
||||
# DeepEP LL uses RDMA so no SMs are used for communication
|
||||
def max_sms_used(self) -> Optional[int]:
|
||||
def max_sms_used(self) -> int | None:
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import sys
|
||||
import tempfile
|
||||
from collections.abc import Sequence
|
||||
from itertools import product
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -86,7 +86,7 @@ def producer(
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices: Optional[str] = None,
|
||||
cuda_visible_devices: str | None = None,
|
||||
):
|
||||
if cuda_visible_devices is not None:
|
||||
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||
@@ -120,7 +120,7 @@ def consumer(
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices: Optional[str] = None,
|
||||
cuda_visible_devices: str | None = None,
|
||||
):
|
||||
if cuda_visible_devices is not None:
|
||||
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||
@@ -253,7 +253,7 @@ def can_actually_p2p(
|
||||
# e.g. used by different vllm engines. The device id in the cache file is a
|
||||
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
|
||||
# of visible devices in the vllm engine.
|
||||
_gpu_p2p_access_cache: Optional[dict[str, bool]] = None
|
||||
_gpu_p2p_access_cache: dict[str, bool] | None = None
|
||||
|
||||
|
||||
def gpu_p2p_access_check(src: int, tgt: int) -> bool:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
from typing import Optional, Union
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import torch
|
||||
@@ -75,7 +74,7 @@ class All2AllManagerBase:
|
||||
def set_num_sms(self, num_sms: int):
|
||||
pass
|
||||
|
||||
def max_sms_used(self) -> Optional[int]:
|
||||
def max_sms_used(self) -> int | None:
|
||||
return None # None means it could use the whole GPU
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False):
|
||||
@@ -96,8 +95,8 @@ class DeviceCommunicatorBase:
|
||||
def __init__(
|
||||
self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
device: torch.device | None = None,
|
||||
device_group: ProcessGroup | None = None,
|
||||
unique_name: str = "",
|
||||
):
|
||||
self.device = device or torch.device("cpu")
|
||||
@@ -123,7 +122,7 @@ class DeviceCommunicatorBase:
|
||||
|
||||
self.is_ep_communicator = "ep" in unique_name
|
||||
self.use_all2all = self.is_ep_communicator and use_ep
|
||||
self.all2all_manager: Optional[All2AllManagerBase] = None
|
||||
self.all2all_manager: All2AllManagerBase | None = None
|
||||
|
||||
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
dist.all_reduce(input_, group=self.device_group)
|
||||
@@ -156,10 +155,10 @@ class DeviceCommunicatorBase:
|
||||
|
||||
def all_gatherv(
|
||||
self,
|
||||
input_: Union[torch.Tensor, list[torch.Tensor]],
|
||||
input_: torch.Tensor | list[torch.Tensor],
|
||||
dim: int = 0,
|
||||
sizes: Optional[list[int]] = None,
|
||||
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||
sizes: list[int] | None = None,
|
||||
) -> torch.Tensor | list[torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
@@ -196,13 +195,13 @@ class DeviceCommunicatorBase:
|
||||
return output_tensor.movedim(0, dim).contiguous()
|
||||
|
||||
def reduce_scatterv(
|
||||
self, input_: torch.Tensor, dim: int = -1, sizes: Optional[list[int]] = None
|
||||
self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def gather(
|
||||
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
|
||||
) -> Optional[torch.Tensor]:
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
NOTE: We assume that the input tensor is on the same device across
|
||||
all the ranks.
|
||||
@@ -231,7 +230,7 @@ class DeviceCommunicatorBase:
|
||||
output_tensor = None
|
||||
return output_tensor
|
||||
|
||||
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
|
||||
def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
|
||||
"""Sends a tensor to the destination rank in a blocking way"""
|
||||
"""NOTE: `dst` is the local rank of the destination rank."""
|
||||
if dst is None:
|
||||
@@ -239,7 +238,7 @@ class DeviceCommunicatorBase:
|
||||
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
|
||||
|
||||
def recv(
|
||||
self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None
|
||||
self, size: torch.Size, dtype: torch.dtype, src: int | None = None
|
||||
) -> torch.Tensor:
|
||||
"""Receives a tensor from the source rank."""
|
||||
"""NOTE: `src` is the local rank of the source rank."""
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
@@ -18,8 +18,8 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
def __init__(
|
||||
self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
device: torch.device | None = None,
|
||||
device_group: ProcessGroup | None = None,
|
||||
unique_name: str = "",
|
||||
):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
@@ -38,7 +38,7 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def gather(
|
||||
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
|
||||
) -> Optional[torch.Tensor]:
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
NOTE: We assume that the input tensor is on the same device across
|
||||
all the ranks.
|
||||
@@ -99,7 +99,7 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def send_tensor_dict(
|
||||
self,
|
||||
tensor_dict: dict[str, Union[torch.Tensor, Any]],
|
||||
tensor_dict: dict[str, torch.Tensor | Any],
|
||||
dst: int,
|
||||
) -> None:
|
||||
return self.dist_module.send_tensor_dict(tensor_dict, dst)
|
||||
@@ -107,7 +107,7 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
def recv_tensor_dict(
|
||||
self,
|
||||
src: int,
|
||||
) -> dict[str, Union[torch.Tensor, Any]]:
|
||||
) -> dict[str, torch.Tensor | Any]:
|
||||
return self.dist_module.recv_tensor_dict(src)
|
||||
|
||||
|
||||
@@ -140,16 +140,16 @@ class _CPUSHMDistributed:
|
||||
return handle
|
||||
|
||||
def all_reduce(
|
||||
self, input: torch.Tensor, group: Optional[ProcessGroup] = None
|
||||
self, input: torch.Tensor, group: ProcessGroup | None = None
|
||||
) -> None:
|
||||
torch.ops._C.shm_allreduce(self.handle, input)
|
||||
|
||||
def gather(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
gather_list: Optional[list[torch.Tensor]],
|
||||
gather_list: list[torch.Tensor] | None,
|
||||
dst: int = -1,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
group: ProcessGroup | None = None,
|
||||
) -> None:
|
||||
# Note: different from the torch gather, here we use local dst rank.
|
||||
torch.ops._C.shm_gather(
|
||||
@@ -163,13 +163,13 @@ class _CPUSHMDistributed:
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
group: ProcessGroup | None = None,
|
||||
) -> None:
|
||||
torch.ops._C.shm_all_gather(self.handle, input, output)
|
||||
|
||||
def send_tensor_dict(
|
||||
self,
|
||||
tensor_dict: dict[str, Union[torch.Tensor, Any]],
|
||||
tensor_dict: dict[str, torch.Tensor | Any],
|
||||
dst: int,
|
||||
) -> None:
|
||||
key_list = list(tensor_dict.keys())
|
||||
@@ -191,7 +191,7 @@ class _CPUSHMDistributed:
|
||||
def recv_tensor_dict(
|
||||
self,
|
||||
src: int,
|
||||
) -> dict[str, Union[torch.Tensor, Any]]:
|
||||
) -> dict[str, torch.Tensor | Any]:
|
||||
tensor_list = torch.ops._C.shm_recv_tensor_list(self.handle, src)
|
||||
|
||||
value_list: list[torch.Tensor] = tensor_list[:-1]
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
@@ -26,8 +25,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
def __init__(
|
||||
self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
device: torch.device | None = None,
|
||||
device_group: ProcessGroup | None = None,
|
||||
unique_name: str = "",
|
||||
):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
@@ -54,7 +53,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
)
|
||||
from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator
|
||||
|
||||
self.pynccl_comm: Optional[PyNcclCommunicator] = None
|
||||
self.pynccl_comm: PyNcclCommunicator | None = None
|
||||
if self.world_size > 1:
|
||||
self.pynccl_comm = PyNcclCommunicator(
|
||||
group=self.cpu_group,
|
||||
@@ -63,9 +62,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
if is_symmetric_memory_enabled():
|
||||
register_nccl_symmetric_ops(self.pynccl_comm)
|
||||
|
||||
self.ca_comm: Optional[CustomAllreduce] = None
|
||||
self.qr_comm: Optional[QuickAllReduce] = None
|
||||
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
|
||||
self.ca_comm: CustomAllreduce | None = None
|
||||
self.qr_comm: QuickAllReduce | None = None
|
||||
self.symm_mem_comm: SymmMemCommunicator | None = None
|
||||
if use_torch_symm_mem and current_platform.is_cuda():
|
||||
self.symm_mem_comm = SymmMemCommunicator(
|
||||
group=self.cpu_group,
|
||||
@@ -201,7 +200,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
return output.movedim(0, dim).contiguous()
|
||||
|
||||
def reduce_scatterv(
|
||||
self, input_: torch.Tensor, dim: int = -1, sizes: Optional[list[int]] = None
|
||||
self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None
|
||||
):
|
||||
world_size = self.world_size
|
||||
pynccl_comm = self.pynccl_comm
|
||||
@@ -235,7 +234,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
# Reshape before returning
|
||||
return output.movedim(0, dim).contiguous()
|
||||
|
||||
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
|
||||
def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
|
||||
"""Sends a tensor to the destination rank in a blocking way"""
|
||||
"""NOTE: `dst` is the local rank of the destination rank."""
|
||||
if dst is None:
|
||||
@@ -248,7 +247,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
|
||||
|
||||
def recv(
|
||||
self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None
|
||||
self, size: torch.Size, dtype: torch.dtype, src: int | None = None
|
||||
) -> torch.Tensor:
|
||||
"""Receives a tensor from the source rank."""
|
||||
"""NOTE: `src` is the local rank of the source rank."""
|
||||
@@ -274,9 +273,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def all_gatherv(
|
||||
self,
|
||||
input_: Union[torch.Tensor, list[torch.Tensor]],
|
||||
input_: torch.Tensor | list[torch.Tensor],
|
||||
dim: int = 0,
|
||||
sizes: Optional[list[int]] = None,
|
||||
sizes: list[int] | None = None,
|
||||
):
|
||||
if dim != 0:
|
||||
raise NotImplementedError("only dim 0 all-gatherv is supported")
|
||||
@@ -289,7 +288,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
if sizes is not None and all(s == sizes[0] for s in sizes):
|
||||
sizes = None
|
||||
|
||||
def _all_gather_single(input_: torch.Tensor, sizes: Optional[list[int]] = None):
|
||||
def _all_gather_single(input_: torch.Tensor, sizes: list[int] | None = None):
|
||||
input_size = input_.size()
|
||||
if sizes is not None:
|
||||
assert len(sizes) == world_size
|
||||
|
||||
@@ -7,7 +7,7 @@ convenient for use when we just need to call a few functions.
|
||||
|
||||
import ctypes
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
# this line makes it possible to directly load `libcudart.so` using `ctypes`
|
||||
import torch # noqa
|
||||
@@ -36,7 +36,7 @@ class Function:
|
||||
argtypes: list[Any]
|
||||
|
||||
|
||||
def find_loaded_library(lib_name) -> Optional[str]:
|
||||
def find_loaded_library(lib_name) -> str | None:
|
||||
"""
|
||||
According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
|
||||
the file `/proc/self/maps` contains the memory maps of the process, which includes the
|
||||
@@ -113,7 +113,7 @@ class CudaRTLibrary:
|
||||
# to the corresponding dictionary
|
||||
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
def __init__(self, so_file: str | None = None):
|
||||
if so_file is None:
|
||||
so_file = find_loaded_library("libcudart")
|
||||
if so_file is None:
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -55,7 +54,7 @@ class CustomAllreduce:
|
||||
def __init__(
|
||||
self,
|
||||
group: ProcessGroup,
|
||||
device: Union[int, str, torch.device],
|
||||
device: int | str | torch.device,
|
||||
max_size=8192 * 1024,
|
||||
symm_mem_enabled=False,
|
||||
) -> None:
|
||||
@@ -260,7 +259,7 @@ class CustomAllreduce:
|
||||
)
|
||||
return out
|
||||
|
||||
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
def custom_all_reduce(self, input: torch.Tensor) -> torch.Tensor | None:
|
||||
"""The main allreduce API that provides support for cuda graph."""
|
||||
# When custom allreduce is disabled, this will be None.
|
||||
if self.disabled or not self.should_custom_ar(input):
|
||||
@@ -292,8 +291,8 @@ class CustomAllreduce:
|
||||
@staticmethod
|
||||
def create_shared_buffer(
|
||||
size_in_bytes: int,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
uncached: Optional[bool] = False,
|
||||
group: ProcessGroup | None = None,
|
||||
uncached: bool | None = False,
|
||||
) -> list[int]:
|
||||
pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
|
||||
|
||||
@@ -313,8 +312,8 @@ class CustomAllreduce:
|
||||
@staticmethod
|
||||
def free_shared_buffer(
|
||||
pointers: list[int],
|
||||
group: Optional[ProcessGroup] = None,
|
||||
rank: Optional[int] = None,
|
||||
group: ProcessGroup | None = None,
|
||||
rank: int | None = None,
|
||||
) -> None:
|
||||
if rank is None:
|
||||
rank = dist.get_rank(group=group)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
# ===================== import region =====================
|
||||
import torch
|
||||
@@ -59,9 +58,9 @@ def register_nccl_symmetric_ops(pynccl_comm):
|
||||
class PyNcclCommunicator:
|
||||
def __init__(
|
||||
self,
|
||||
group: Union[ProcessGroup, StatelessProcessGroup],
|
||||
device: Union[int, str, torch.device],
|
||||
library_path: Optional[str] = None,
|
||||
group: ProcessGroup | StatelessProcessGroup,
|
||||
device: int | str | torch.device,
|
||||
library_path: str | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import atexit
|
||||
import contextlib
|
||||
import tempfile
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
@@ -141,7 +141,7 @@ class nccl_symm_mem_context:
|
||||
or version.parse(torch.__version__) < version.parse("2.8.0.a0")
|
||||
)
|
||||
if self.disabled:
|
||||
self.pynccl_comm: Optional[PyNcclCommunicator] = None
|
||||
self.pynccl_comm: PyNcclCommunicator | None = None
|
||||
self._mem_pool_ctx: contextlib.AbstractContextManager[Any] = (
|
||||
contextlib.nullcontext()
|
||||
)
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
import ctypes
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
@@ -305,7 +305,7 @@ class NCCLLibrary:
|
||||
# to the corresponding dictionary
|
||||
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
def __init__(self, so_file: str | None = None):
|
||||
so_file = so_file or find_nccl_library()
|
||||
|
||||
try:
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from enum import Enum
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -58,9 +57,7 @@ class QuickAllReduce:
|
||||
(torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB],
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self, group: ProcessGroup, device: Union[int, str, torch.device]
|
||||
) -> None:
|
||||
def __init__(self, group: ProcessGroup, device: int | str | torch.device) -> None:
|
||||
"""
|
||||
Custom allreduce provides non-destructive acceleration and is
|
||||
available for CUDA and ROCm MI300 series.
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import ray
|
||||
import torch
|
||||
@@ -27,15 +27,15 @@ class RayPPCommunicator(Communicator):
|
||||
This class is not thread-safe.
|
||||
"""
|
||||
|
||||
_comm: Optional[DeviceCommunicatorBase]
|
||||
_comm: DeviceCommunicatorBase | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
world_size: int,
|
||||
comm_id: Any,
|
||||
rank: Optional[int],
|
||||
rank: int | None,
|
||||
actor_handles: list["ray.actor.ActorHandle"],
|
||||
cuda_stream: Optional[torch.cuda.Stream],
|
||||
cuda_stream: torch.cuda.Stream | None,
|
||||
use_communication_streams: bool = False,
|
||||
):
|
||||
"""
|
||||
@@ -56,7 +56,7 @@ class RayPPCommunicator(Communicator):
|
||||
This is not supported.
|
||||
"""
|
||||
self._world_size = world_size
|
||||
self._rank: Optional[int] = None
|
||||
self._rank: int | None = None
|
||||
self._actor_handles = actor_handles
|
||||
if use_communication_streams:
|
||||
raise NotImplementedError("use_communication_streams is not supported")
|
||||
@@ -143,7 +143,7 @@ class RayPPCommunicator(Communicator):
|
||||
else:
|
||||
raise ValueError(f"Actor {actor} not found in communicator group")
|
||||
|
||||
def get_self_rank(self) -> Optional[int]:
|
||||
def get_self_rank(self) -> int | None:
|
||||
"""
|
||||
Return this actor's rank.
|
||||
"""
|
||||
|
||||
@@ -7,7 +7,7 @@ from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from multiprocessing import shared_memory
|
||||
from threading import Event
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
@@ -80,7 +80,7 @@ class ShmRingBuffer:
|
||||
n_reader: int,
|
||||
max_chunk_bytes: int,
|
||||
max_chunks: int,
|
||||
name: Optional[str] = None,
|
||||
name: str | None = None,
|
||||
):
|
||||
"""
|
||||
A shared memory ring buffer implementation for broadcast communication.
|
||||
@@ -213,9 +213,9 @@ class ShmRingBuffer:
|
||||
class Handle:
|
||||
local_reader_ranks: list[int] = field(default_factory=list)
|
||||
|
||||
buffer_handle: Optional[tuple[int, int, int, str]] = None
|
||||
local_subscribe_addr: Optional[str] = None
|
||||
remote_subscribe_addr: Optional[str] = None
|
||||
buffer_handle: tuple[int, int, int, str] | None = None
|
||||
local_subscribe_addr: str | None = None
|
||||
remote_subscribe_addr: str | None = None
|
||||
remote_addr_ipv6: bool = False
|
||||
|
||||
|
||||
@@ -224,10 +224,10 @@ class MessageQueue:
|
||||
self,
|
||||
n_reader, # number of all readers
|
||||
n_local_reader, # number of local readers through shared memory
|
||||
local_reader_ranks: Optional[list[int]] = None,
|
||||
local_reader_ranks: list[int] | None = None,
|
||||
max_chunk_bytes: int = 1024 * 1024 * 10,
|
||||
max_chunks: int = 10,
|
||||
connect_ip: Optional[str] = None,
|
||||
connect_ip: str | None = None,
|
||||
):
|
||||
if local_reader_ranks is None:
|
||||
local_reader_ranks = list(range(n_local_reader))
|
||||
@@ -384,7 +384,7 @@ class MessageQueue:
|
||||
assert recv == b"READY"
|
||||
|
||||
@contextmanager
|
||||
def acquire_write(self, timeout: Optional[float] = None):
|
||||
def acquire_write(self, timeout: float | None = None):
|
||||
assert self._is_writer, "Only writers can acquire write"
|
||||
start_time = time.monotonic()
|
||||
n_warning = 1
|
||||
@@ -444,8 +444,8 @@ class MessageQueue:
|
||||
@contextmanager
|
||||
def acquire_read(
|
||||
self,
|
||||
timeout: Optional[float] = None,
|
||||
cancel: Optional[Event] = None,
|
||||
timeout: float | None = None,
|
||||
cancel: Event | None = None,
|
||||
indefinite: bool = False,
|
||||
):
|
||||
assert self._is_local_reader, "Only readers can acquire read"
|
||||
@@ -502,7 +502,7 @@ class MessageQueue:
|
||||
self._read_spin_timer.record_activity()
|
||||
break
|
||||
|
||||
def enqueue(self, obj, timeout: Optional[float] = None):
|
||||
def enqueue(self, obj, timeout: float | None = None):
|
||||
"""Write to message queue with optional timeout (in seconds)"""
|
||||
assert self._is_writer, "Only writers can enqueue"
|
||||
serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
@@ -520,8 +520,8 @@ class MessageQueue:
|
||||
|
||||
def dequeue(
|
||||
self,
|
||||
timeout: Optional[float] = None,
|
||||
cancel: Optional[Event] = None,
|
||||
timeout: float | None = None,
|
||||
cancel: Event | None = None,
|
||||
indefinite: bool = False,
|
||||
):
|
||||
"""Read from message queue with optional timeout (in seconds)"""
|
||||
@@ -542,7 +542,7 @@ class MessageQueue:
|
||||
return obj
|
||||
|
||||
@staticmethod
|
||||
def recv(socket: zmq.Socket, timeout: Optional[float]) -> Any:
|
||||
def recv(socket: zmq.Socket, timeout: float | None) -> Any:
|
||||
timeout_ms = None if timeout is None else int(timeout * 1000)
|
||||
if not socket.poll(timeout=timeout_ms):
|
||||
raise TimeoutError
|
||||
@@ -558,7 +558,7 @@ class MessageQueue:
|
||||
|
||||
@staticmethod
|
||||
def create_from_process_group(
|
||||
pg: Union[ProcessGroup, StatelessProcessGroup],
|
||||
pg: ProcessGroup | StatelessProcessGroup,
|
||||
max_chunk_bytes,
|
||||
max_chunks,
|
||||
writer_rank=0,
|
||||
|
||||
@@ -3,13 +3,13 @@
|
||||
|
||||
import pickle
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Callable, Iterable
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from itertools import chain
|
||||
from multiprocessing import shared_memory
|
||||
from multiprocessing.synchronize import Lock as LockType
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
@@ -109,7 +109,7 @@ class SingleWriterShmRingBuffer:
|
||||
def __init__(
|
||||
self,
|
||||
data_buffer_size: int,
|
||||
name: Optional[str] = None,
|
||||
name: str | None = None,
|
||||
create: bool = False,
|
||||
):
|
||||
self.data_buffer_size = data_buffer_size
|
||||
@@ -252,7 +252,7 @@ class SingleWriterShmRingBuffer:
|
||||
def free_buf(
|
||||
self,
|
||||
is_free_fn: Callable[[int, memoryview], bool],
|
||||
nbytes: Optional[int] = None,
|
||||
nbytes: int | None = None,
|
||||
) -> Iterable[int]:
|
||||
"""
|
||||
Free a buffer of the given size. This is a no-op in shared memory,
|
||||
@@ -340,9 +340,7 @@ class MsgpackSerde(ObjectSerde):
|
||||
self.mm_decoder = MsgpackDecoder(MultiModalKwargsItem)
|
||||
self._mm_kwargs_item_cls = MultiModalKwargsItem
|
||||
|
||||
def serialize(
|
||||
self, value: Any
|
||||
) -> tuple[Union[bytes, list[bytes]], int, bytes, int]:
|
||||
def serialize(self, value: Any) -> tuple[bytes | list[bytes], int, bytes, int]:
|
||||
len_arr = None
|
||||
if isinstance(value, (torch.Tensor, self._mm_kwargs_item_cls)):
|
||||
type_name = type(value).__name__
|
||||
@@ -396,7 +394,7 @@ class ShmObjectStorageHandle:
|
||||
n_readers: int
|
||||
ring_buffer_handle: tuple[int, str]
|
||||
serde_class: type[ObjectSerde]
|
||||
reader_lock: Optional[LockType]
|
||||
reader_lock: LockType | None
|
||||
|
||||
|
||||
class SingleWriterShmObjectStorage:
|
||||
@@ -444,7 +442,7 @@ class SingleWriterShmObjectStorage:
|
||||
n_readers: int,
|
||||
ring_buffer: SingleWriterShmRingBuffer,
|
||||
serde_class: type[ObjectSerde] = MsgpackSerde,
|
||||
reader_lock: Optional[LockType] = None,
|
||||
reader_lock: LockType | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the object storage.
|
||||
@@ -492,7 +490,7 @@ class SingleWriterShmObjectStorage:
|
||||
|
||||
def copy_to_buffer(
|
||||
self,
|
||||
data: Union[bytes, list[bytes]],
|
||||
data: bytes | list[bytes],
|
||||
data_bytes: int,
|
||||
metadata: bytes,
|
||||
md_bytes: int,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -31,10 +30,10 @@ class SymmMemCommunicator:
|
||||
def __init__(
|
||||
self,
|
||||
group: ProcessGroup,
|
||||
device: Union[int, str, torch.device],
|
||||
device: int | str | torch.device,
|
||||
# add options for testing
|
||||
force_multimem: Optional[bool] = None,
|
||||
max_size_override: Optional[int] = None,
|
||||
force_multimem: bool | None = None,
|
||||
max_size_override: int | None = None,
|
||||
):
|
||||
self.disabled = True
|
||||
|
||||
@@ -108,8 +107,8 @@ class SymmMemCommunicator:
|
||||
return inp_size < self.max_size
|
||||
|
||||
def all_reduce(
|
||||
self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None
|
||||
) -> Optional[torch.Tensor]:
|
||||
self, inp: torch.Tensor, *, out: torch.Tensor | None = None
|
||||
) -> torch.Tensor | None:
|
||||
if not self.should_use_symm_mem(inp):
|
||||
return None
|
||||
if out is None:
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
@@ -39,8 +38,8 @@ class TpuCommunicator(DeviceCommunicatorBase):
|
||||
def __init__(
|
||||
self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
device: torch.device | None = None,
|
||||
device_group: ProcessGroup | None = None,
|
||||
unique_name: str = "",
|
||||
):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -19,8 +18,8 @@ class XpuCommunicator(DeviceCommunicatorBase):
|
||||
def __init__(
|
||||
self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
device: torch.device | None = None,
|
||||
device_group: ProcessGroup | None = None,
|
||||
unique_name: str = "",
|
||||
):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
@@ -45,7 +44,7 @@ class XpuCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def gather(
|
||||
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
|
||||
) -> Optional[torch.Tensor]:
|
||||
) -> torch.Tensor | None:
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user