Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-12 17:51:31 +01:00
committed by GitHub
parent 9bb38130cb
commit 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Any
import torch
import torch.distributed as dist
@@ -368,7 +368,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
return handle
# DeepEP LL uses RDMA so no SMs are used for communication
def max_sms_used(self) -> Optional[int]:
def max_sms_used(self) -> int | None:
return 0

View File

@@ -10,7 +10,7 @@ import sys
import tempfile
from collections.abc import Sequence
from itertools import product
from typing import Any, Optional
from typing import Any
import torch
import torch.distributed as dist
@@ -86,7 +86,7 @@ def producer(
producer_queue,
consumer_queue,
result_queue,
cuda_visible_devices: Optional[str] = None,
cuda_visible_devices: str | None = None,
):
if cuda_visible_devices is not None:
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
@@ -120,7 +120,7 @@ def consumer(
producer_queue,
consumer_queue,
result_queue,
cuda_visible_devices: Optional[str] = None,
cuda_visible_devices: str | None = None,
):
if cuda_visible_devices is not None:
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
@@ -253,7 +253,7 @@ def can_actually_p2p(
# e.g. used by different vllm engines. The device id in the cache file is a
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
# of visible devices in the vllm engine.
_gpu_p2p_access_cache: Optional[dict[str, bool]] = None
_gpu_p2p_access_cache: dict[str, bool] | None = None
def gpu_p2p_access_check(src: int, tgt: int) -> bool:

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from typing import Optional, Union
from weakref import WeakValueDictionary
import torch
@@ -75,7 +74,7 @@ class All2AllManagerBase:
def set_num_sms(self, num_sms: int):
pass
def max_sms_used(self) -> Optional[int]:
def max_sms_used(self) -> int | None:
return None # None means it could use the whole GPU
def combine(self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False):
@@ -96,8 +95,8 @@ class DeviceCommunicatorBase:
def __init__(
self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
):
self.device = device or torch.device("cpu")
@@ -123,7 +122,7 @@ class DeviceCommunicatorBase:
self.is_ep_communicator = "ep" in unique_name
self.use_all2all = self.is_ep_communicator and use_ep
self.all2all_manager: Optional[All2AllManagerBase] = None
self.all2all_manager: All2AllManagerBase | None = None
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
dist.all_reduce(input_, group=self.device_group)
@@ -156,10 +155,10 @@ class DeviceCommunicatorBase:
def all_gatherv(
self,
input_: Union[torch.Tensor, list[torch.Tensor]],
input_: torch.Tensor | list[torch.Tensor],
dim: int = 0,
sizes: Optional[list[int]] = None,
) -> Union[torch.Tensor, list[torch.Tensor]]:
sizes: list[int] | None = None,
) -> torch.Tensor | list[torch.Tensor]:
raise NotImplementedError
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
@@ -196,13 +195,13 @@ class DeviceCommunicatorBase:
return output_tensor.movedim(0, dim).contiguous()
def reduce_scatterv(
self, input_: torch.Tensor, dim: int = -1, sizes: Optional[list[int]] = None
self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None
) -> torch.Tensor:
raise NotImplementedError
def gather(
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
) -> torch.Tensor | None:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
@@ -231,7 +230,7 @@ class DeviceCommunicatorBase:
output_tensor = None
return output_tensor
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
"""Sends a tensor to the destination rank in a blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
@@ -239,7 +238,7 @@ class DeviceCommunicatorBase:
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
def recv(
self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None
self, size: torch.Size, dtype: torch.dtype, src: int | None = None
) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Any, Optional, Union
from typing import Any
import torch
from torch.distributed import ProcessGroup
@@ -18,8 +18,8 @@ class CpuCommunicator(DeviceCommunicatorBase):
def __init__(
self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
):
super().__init__(cpu_group, device, device_group, unique_name)
@@ -38,7 +38,7 @@ class CpuCommunicator(DeviceCommunicatorBase):
def gather(
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
) -> torch.Tensor | None:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
@@ -99,7 +99,7 @@ class CpuCommunicator(DeviceCommunicatorBase):
def send_tensor_dict(
self,
tensor_dict: dict[str, Union[torch.Tensor, Any]],
tensor_dict: dict[str, torch.Tensor | Any],
dst: int,
) -> None:
return self.dist_module.send_tensor_dict(tensor_dict, dst)
@@ -107,7 +107,7 @@ class CpuCommunicator(DeviceCommunicatorBase):
def recv_tensor_dict(
self,
src: int,
) -> dict[str, Union[torch.Tensor, Any]]:
) -> dict[str, torch.Tensor | Any]:
return self.dist_module.recv_tensor_dict(src)
@@ -140,16 +140,16 @@ class _CPUSHMDistributed:
return handle
def all_reduce(
self, input: torch.Tensor, group: Optional[ProcessGroup] = None
self, input: torch.Tensor, group: ProcessGroup | None = None
) -> None:
torch.ops._C.shm_allreduce(self.handle, input)
def gather(
self,
input: torch.Tensor,
gather_list: Optional[list[torch.Tensor]],
gather_list: list[torch.Tensor] | None,
dst: int = -1,
group: Optional[ProcessGroup] = None,
group: ProcessGroup | None = None,
) -> None:
# Note: different from the torch gather, here we use local dst rank.
torch.ops._C.shm_gather(
@@ -163,13 +163,13 @@ class _CPUSHMDistributed:
self,
output: torch.Tensor,
input: torch.Tensor,
group: Optional[ProcessGroup] = None,
group: ProcessGroup | None = None,
) -> None:
torch.ops._C.shm_all_gather(self.handle, input, output)
def send_tensor_dict(
self,
tensor_dict: dict[str, Union[torch.Tensor, Any]],
tensor_dict: dict[str, torch.Tensor | Any],
dst: int,
) -> None:
key_list = list(tensor_dict.keys())
@@ -191,7 +191,7 @@ class _CPUSHMDistributed:
def recv_tensor_dict(
self,
src: int,
) -> dict[str, Union[torch.Tensor, Any]]:
) -> dict[str, torch.Tensor | Any]:
tensor_list = torch.ops._C.shm_recv_tensor_list(self.handle, src)
value_list: list[torch.Tensor] = tensor_list[:-1]

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
import torch
from torch.distributed import ProcessGroup
@@ -26,8 +25,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
def __init__(
self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
):
super().__init__(cpu_group, device, device_group, unique_name)
@@ -54,7 +53,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
)
from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator
self.pynccl_comm: Optional[PyNcclCommunicator] = None
self.pynccl_comm: PyNcclCommunicator | None = None
if self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
@@ -63,9 +62,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
if is_symmetric_memory_enabled():
register_nccl_symmetric_ops(self.pynccl_comm)
self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
self.ca_comm: CustomAllreduce | None = None
self.qr_comm: QuickAllReduce | None = None
self.symm_mem_comm: SymmMemCommunicator | None = None
if use_torch_symm_mem and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
@@ -201,7 +200,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
return output.movedim(0, dim).contiguous()
def reduce_scatterv(
self, input_: torch.Tensor, dim: int = -1, sizes: Optional[list[int]] = None
self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None
):
world_size = self.world_size
pynccl_comm = self.pynccl_comm
@@ -235,7 +234,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
# Reshape before returning
return output.movedim(0, dim).contiguous()
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
"""Sends a tensor to the destination rank in a blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
@@ -248,7 +247,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
def recv(
self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None
self, size: torch.Size, dtype: torch.dtype, src: int | None = None
) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
@@ -274,9 +273,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
def all_gatherv(
self,
input_: Union[torch.Tensor, list[torch.Tensor]],
input_: torch.Tensor | list[torch.Tensor],
dim: int = 0,
sizes: Optional[list[int]] = None,
sizes: list[int] | None = None,
):
if dim != 0:
raise NotImplementedError("only dim 0 all-gatherv is supported")
@@ -289,7 +288,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
if sizes is not None and all(s == sizes[0] for s in sizes):
sizes = None
def _all_gather_single(input_: torch.Tensor, sizes: Optional[list[int]] = None):
def _all_gather_single(input_: torch.Tensor, sizes: list[int] | None = None):
input_size = input_.size()
if sizes is not None:
assert len(sizes) == world_size

View File

@@ -7,7 +7,7 @@ convenient for use when we just need to call a few functions.
import ctypes
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any
# this line makes it possible to directly load `libcudart.so` using `ctypes`
import torch # noqa
@@ -36,7 +36,7 @@ class Function:
argtypes: list[Any]
def find_loaded_library(lib_name) -> Optional[str]:
def find_loaded_library(lib_name) -> str | None:
"""
According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
the file `/proc/self/maps` contains the memory maps of the process, which includes the
@@ -113,7 +113,7 @@ class CudaRTLibrary:
# to the corresponding dictionary
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
def __init__(self, so_file: str | None = None):
if so_file is None:
so_file = find_loaded_library("libcudart")
if so_file is None:

View File

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from typing import Optional, Union
import torch
import torch.distributed as dist
@@ -55,7 +54,7 @@ class CustomAllreduce:
def __init__(
self,
group: ProcessGroup,
device: Union[int, str, torch.device],
device: int | str | torch.device,
max_size=8192 * 1024,
symm_mem_enabled=False,
) -> None:
@@ -260,7 +259,7 @@ class CustomAllreduce:
)
return out
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
def custom_all_reduce(self, input: torch.Tensor) -> torch.Tensor | None:
"""The main allreduce API that provides support for cuda graph."""
# When custom allreduce is disabled, this will be None.
if self.disabled or not self.should_custom_ar(input):
@@ -292,8 +291,8 @@ class CustomAllreduce:
@staticmethod
def create_shared_buffer(
size_in_bytes: int,
group: Optional[ProcessGroup] = None,
uncached: Optional[bool] = False,
group: ProcessGroup | None = None,
uncached: bool | None = False,
) -> list[int]:
pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
@@ -313,8 +312,8 @@ class CustomAllreduce:
@staticmethod
def free_shared_buffer(
pointers: list[int],
group: Optional[ProcessGroup] = None,
rank: Optional[int] = None,
group: ProcessGroup | None = None,
rank: int | None = None,
) -> None:
if rank is None:
rank = dist.get_rank(group=group)

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
# ===================== import region =====================
import torch
@@ -59,9 +58,9 @@ def register_nccl_symmetric_ops(pynccl_comm):
class PyNcclCommunicator:
def __init__(
self,
group: Union[ProcessGroup, StatelessProcessGroup],
device: Union[int, str, torch.device],
library_path: Optional[str] = None,
group: ProcessGroup | StatelessProcessGroup,
device: int | str | torch.device,
library_path: str | None = None,
):
"""
Args:

View File

@@ -3,7 +3,7 @@
import atexit
import contextlib
import tempfile
from typing import Any, Optional
from typing import Any
import torch
from packaging import version
@@ -141,7 +141,7 @@ class nccl_symm_mem_context:
or version.parse(torch.__version__) < version.parse("2.8.0.a0")
)
if self.disabled:
self.pynccl_comm: Optional[PyNcclCommunicator] = None
self.pynccl_comm: PyNcclCommunicator | None = None
self._mem_pool_ctx: contextlib.AbstractContextManager[Any] = (
contextlib.nullcontext()
)

View File

@@ -25,7 +25,7 @@
import ctypes
import platform
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any
import torch
from torch.distributed import ReduceOp
@@ -305,7 +305,7 @@ class NCCLLibrary:
# to the corresponding dictionary
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
def __init__(self, so_file: str | None = None):
so_file = so_file or find_nccl_library()
try:

View File

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import Union
import torch
import torch.distributed as dist
@@ -58,9 +57,7 @@ class QuickAllReduce:
(torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB],
}
def __init__(
self, group: ProcessGroup, device: Union[int, str, torch.device]
) -> None:
def __init__(self, group: ProcessGroup, device: int | str | torch.device) -> None:
"""
Custom allreduce provides non-destructive acceleration and is
available for CUDA and ROCm MI300 series.

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import uuid
from typing import Any, Optional
from typing import Any
import ray
import torch
@@ -27,15 +27,15 @@ class RayPPCommunicator(Communicator):
This class is not thread-safe.
"""
_comm: Optional[DeviceCommunicatorBase]
_comm: DeviceCommunicatorBase | None
def __init__(
self,
world_size: int,
comm_id: Any,
rank: Optional[int],
rank: int | None,
actor_handles: list["ray.actor.ActorHandle"],
cuda_stream: Optional[torch.cuda.Stream],
cuda_stream: torch.cuda.Stream | None,
use_communication_streams: bool = False,
):
"""
@@ -56,7 +56,7 @@ class RayPPCommunicator(Communicator):
This is not supported.
"""
self._world_size = world_size
self._rank: Optional[int] = None
self._rank: int | None = None
self._actor_handles = actor_handles
if use_communication_streams:
raise NotImplementedError("use_communication_streams is not supported")
@@ -143,7 +143,7 @@ class RayPPCommunicator(Communicator):
else:
raise ValueError(f"Actor {actor} not found in communicator group")
def get_self_rank(self) -> Optional[int]:
def get_self_rank(self) -> int | None:
"""
Return this actor's rank.
"""

View File

@@ -7,7 +7,7 @@ from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
from threading import Event
from typing import Any, Optional, Union
from typing import Any
from unittest.mock import patch
import torch
@@ -80,7 +80,7 @@ class ShmRingBuffer:
n_reader: int,
max_chunk_bytes: int,
max_chunks: int,
name: Optional[str] = None,
name: str | None = None,
):
"""
A shared memory ring buffer implementation for broadcast communication.
@@ -213,9 +213,9 @@ class ShmRingBuffer:
class Handle:
local_reader_ranks: list[int] = field(default_factory=list)
buffer_handle: Optional[tuple[int, int, int, str]] = None
local_subscribe_addr: Optional[str] = None
remote_subscribe_addr: Optional[str] = None
buffer_handle: tuple[int, int, int, str] | None = None
local_subscribe_addr: str | None = None
remote_subscribe_addr: str | None = None
remote_addr_ipv6: bool = False
@@ -224,10 +224,10 @@ class MessageQueue:
self,
n_reader, # number of all readers
n_local_reader, # number of local readers through shared memory
local_reader_ranks: Optional[list[int]] = None,
local_reader_ranks: list[int] | None = None,
max_chunk_bytes: int = 1024 * 1024 * 10,
max_chunks: int = 10,
connect_ip: Optional[str] = None,
connect_ip: str | None = None,
):
if local_reader_ranks is None:
local_reader_ranks = list(range(n_local_reader))
@@ -384,7 +384,7 @@ class MessageQueue:
assert recv == b"READY"
@contextmanager
def acquire_write(self, timeout: Optional[float] = None):
def acquire_write(self, timeout: float | None = None):
assert self._is_writer, "Only writers can acquire write"
start_time = time.monotonic()
n_warning = 1
@@ -444,8 +444,8 @@ class MessageQueue:
@contextmanager
def acquire_read(
self,
timeout: Optional[float] = None,
cancel: Optional[Event] = None,
timeout: float | None = None,
cancel: Event | None = None,
indefinite: bool = False,
):
assert self._is_local_reader, "Only readers can acquire read"
@@ -502,7 +502,7 @@ class MessageQueue:
self._read_spin_timer.record_activity()
break
def enqueue(self, obj, timeout: Optional[float] = None):
def enqueue(self, obj, timeout: float | None = None):
"""Write to message queue with optional timeout (in seconds)"""
assert self._is_writer, "Only writers can enqueue"
serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
@@ -520,8 +520,8 @@ class MessageQueue:
def dequeue(
self,
timeout: Optional[float] = None,
cancel: Optional[Event] = None,
timeout: float | None = None,
cancel: Event | None = None,
indefinite: bool = False,
):
"""Read from message queue with optional timeout (in seconds)"""
@@ -542,7 +542,7 @@ class MessageQueue:
return obj
@staticmethod
def recv(socket: zmq.Socket, timeout: Optional[float]) -> Any:
def recv(socket: zmq.Socket, timeout: float | None) -> Any:
timeout_ms = None if timeout is None else int(timeout * 1000)
if not socket.poll(timeout=timeout_ms):
raise TimeoutError
@@ -558,7 +558,7 @@ class MessageQueue:
@staticmethod
def create_from_process_group(
pg: Union[ProcessGroup, StatelessProcessGroup],
pg: ProcessGroup | StatelessProcessGroup,
max_chunk_bytes,
max_chunks,
writer_rank=0,

View File

@@ -3,13 +3,13 @@
import pickle
from abc import ABC, abstractmethod
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import chain
from multiprocessing import shared_memory
from multiprocessing.synchronize import Lock as LockType
from typing import Any, Callable, Optional, Union
from typing import Any
from unittest.mock import patch
import torch
@@ -109,7 +109,7 @@ class SingleWriterShmRingBuffer:
def __init__(
self,
data_buffer_size: int,
name: Optional[str] = None,
name: str | None = None,
create: bool = False,
):
self.data_buffer_size = data_buffer_size
@@ -252,7 +252,7 @@ class SingleWriterShmRingBuffer:
def free_buf(
self,
is_free_fn: Callable[[int, memoryview], bool],
nbytes: Optional[int] = None,
nbytes: int | None = None,
) -> Iterable[int]:
"""
Free a buffer of the given size. This is a no-op in shared memory,
@@ -340,9 +340,7 @@ class MsgpackSerde(ObjectSerde):
self.mm_decoder = MsgpackDecoder(MultiModalKwargsItem)
self._mm_kwargs_item_cls = MultiModalKwargsItem
def serialize(
self, value: Any
) -> tuple[Union[bytes, list[bytes]], int, bytes, int]:
def serialize(self, value: Any) -> tuple[bytes | list[bytes], int, bytes, int]:
len_arr = None
if isinstance(value, (torch.Tensor, self._mm_kwargs_item_cls)):
type_name = type(value).__name__
@@ -396,7 +394,7 @@ class ShmObjectStorageHandle:
n_readers: int
ring_buffer_handle: tuple[int, str]
serde_class: type[ObjectSerde]
reader_lock: Optional[LockType]
reader_lock: LockType | None
class SingleWriterShmObjectStorage:
@@ -444,7 +442,7 @@ class SingleWriterShmObjectStorage:
n_readers: int,
ring_buffer: SingleWriterShmRingBuffer,
serde_class: type[ObjectSerde] = MsgpackSerde,
reader_lock: Optional[LockType] = None,
reader_lock: LockType | None = None,
):
"""
Initialize the object storage.
@@ -492,7 +490,7 @@ class SingleWriterShmObjectStorage:
def copy_to_buffer(
self,
data: Union[bytes, list[bytes]],
data: bytes | list[bytes],
data_bytes: int,
metadata: bytes,
md_bytes: int,

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
import torch
import torch.distributed as dist
@@ -31,10 +30,10 @@ class SymmMemCommunicator:
def __init__(
self,
group: ProcessGroup,
device: Union[int, str, torch.device],
device: int | str | torch.device,
# add options for testing
force_multimem: Optional[bool] = None,
max_size_override: Optional[int] = None,
force_multimem: bool | None = None,
max_size_override: int | None = None,
):
self.disabled = True
@@ -108,8 +107,8 @@ class SymmMemCommunicator:
return inp_size < self.max_size
def all_reduce(
self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None
) -> Optional[torch.Tensor]:
self, inp: torch.Tensor, *, out: torch.Tensor | None = None
) -> torch.Tensor | None:
if not self.should_use_symm_mem(inp):
return None
if out is None:

View File

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Optional
import torch
from torch.distributed import ProcessGroup
@@ -39,8 +38,8 @@ class TpuCommunicator(DeviceCommunicatorBase):
def __init__(
self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
):
super().__init__(cpu_group, device, device_group, unique_name)

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import torch.distributed as dist
@@ -19,8 +18,8 @@ class XpuCommunicator(DeviceCommunicatorBase):
def __init__(
self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
):
super().__init__(cpu_group, device, device_group, unique_name)
@@ -45,7 +44,7 @@ class XpuCommunicator(DeviceCommunicatorBase):
def gather(
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
) -> torch.Tensor | None:
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
)