Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
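
For context: `Optional[X]` is exactly `Union[X, None]`, and PEP 604's `X | Y` spelling is equivalent to both for type checkers; at runtime the `|` syntax in annotations needs Python 3.10+ (or `from __future__ import annotations` on older interpreters). A minimal standalone sketch of the spelling change applied throughout this diff (illustrative only; `maybe_double` is a made-up example, not code from this commit):

    # Before: spellings imported from the typing module.
    from typing import Optional, Union

    def maybe_double(x: Union[int, float], enabled: bool) -> Optional[float]:
        return float(x) * 2 if enabled else None

    # After: built-in union syntax, no typing imports needed
    # (Python 3.10+, or `from __future__ import annotations`).
    def maybe_double(x: int | float, enabled: bool) -> float | None:
        return float(x) * 2 if enabled else None
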
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import os
-from typing import Any, Optional, Union
+from typing import Any
 
 import torch
 from torch.distributed import ProcessGroup
@@ -18,8 +18,8 @@ class CpuCommunicator(DeviceCommunicatorBase):
     def __init__(
         self,
         cpu_group: ProcessGroup,
-        device: Optional[torch.device] = None,
-        device_group: Optional[ProcessGroup] = None,
+        device: torch.device | None = None,
+        device_group: ProcessGroup | None = None,
         unique_name: str = "",
     ):
         super().__init__(cpu_group, device, device_group, unique_name)
@@ -38,7 +38,7 @@ class CpuCommunicator(DeviceCommunicatorBase):
 
     def gather(
         self, input_: torch.Tensor, dst: int = 0, dim: int = -1
-    ) -> Optional[torch.Tensor]:
+    ) -> torch.Tensor | None:
         """
         NOTE: We assume that the input tensor is on the same device across
         all the ranks.
@@ -99,7 +99,7 @@ class CpuCommunicator(DeviceCommunicatorBase):
 
     def send_tensor_dict(
         self,
-        tensor_dict: dict[str, Union[torch.Tensor, Any]],
+        tensor_dict: dict[str, torch.Tensor | Any],
         dst: int,
     ) -> None:
         return self.dist_module.send_tensor_dict(tensor_dict, dst)
@@ -107,7 +107,7 @@ class CpuCommunicator(DeviceCommunicatorBase):
     def recv_tensor_dict(
         self,
         src: int,
-    ) -> dict[str, Union[torch.Tensor, Any]]:
+    ) -> dict[str, torch.Tensor | Any]:
         return self.dist_module.recv_tensor_dict(src)
 
 
@@ -140,16 +140,16 @@ class _CPUSHMDistributed:
         return handle
 
     def all_reduce(
-        self, input: torch.Tensor, group: Optional[ProcessGroup] = None
+        self, input: torch.Tensor, group: ProcessGroup | None = None
     ) -> None:
         torch.ops._C.shm_allreduce(self.handle, input)
 
     def gather(
         self,
         input: torch.Tensor,
-        gather_list: Optional[list[torch.Tensor]],
+        gather_list: list[torch.Tensor] | None,
         dst: int = -1,
-        group: Optional[ProcessGroup] = None,
+        group: ProcessGroup | None = None,
     ) -> None:
         # Note: different from the torch gather, here we use local dst rank.
         torch.ops._C.shm_gather(
@@ -163,13 +163,13 @@ class _CPUSHMDistributed:
         self,
         output: torch.Tensor,
         input: torch.Tensor,
-        group: Optional[ProcessGroup] = None,
+        group: ProcessGroup | None = None,
     ) -> None:
         torch.ops._C.shm_all_gather(self.handle, input, output)
 
     def send_tensor_dict(
         self,
-        tensor_dict: dict[str, Union[torch.Tensor, Any]],
+        tensor_dict: dict[str, torch.Tensor | Any],
         dst: int,
     ) -> None:
         key_list = list(tensor_dict.keys())
@@ -191,7 +191,7 @@ class _CPUSHMDistributed:
     def recv_tensor_dict(
         self,
         src: int,
-    ) -> dict[str, Union[torch.Tensor, Any]]:
+    ) -> dict[str, torch.Tensor | Any]:
         tensor_list = torch.ops._C.shm_recv_tensor_list(self.handle, src)
 
         value_list: list[torch.Tensor] = tensor_list[:-1]