Update Optional[x] -> x | None and Union[x, y] -> x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
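Context for the change: PEP 604 (Python 3.10) allows the `|` operator in type annotations in place of `typing.Optional` and `typing.Union`, so the `typing` imports can be dropped entirely. A minimal before/after sketch of the rewrite, using a hypothetical `lookup` function for illustration:

    # Before (pre-3.10 style): requires imports from typing
    from typing import Optional, Union

    def lookup(key: Union[int, str], default: Optional[int] = None) -> Optional[int]:
        ...

    # After (PEP 604, Python 3.10+): built-in `|` syntax, no typing import needed
    def lookup(key: int | str, default: int | None = None) -> int | None:
        ...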
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import threading
-from typing import Optional, Union
 from weakref import WeakValueDictionary
 
 import torch
@@ -75,7 +74,7 @@ class All2AllManagerBase:
     def set_num_sms(self, num_sms: int):
         pass
 
-    def max_sms_used(self) -> Optional[int]:
+    def max_sms_used(self) -> int | None:
         return None  # None means it could use the whole GPU
 
     def combine(self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False):
@@ -96,8 +95,8 @@ class DeviceCommunicatorBase:
     def __init__(
         self,
         cpu_group: ProcessGroup,
-        device: Optional[torch.device] = None,
-        device_group: Optional[ProcessGroup] = None,
+        device: torch.device | None = None,
+        device_group: ProcessGroup | None = None,
         unique_name: str = "",
     ):
         self.device = device or torch.device("cpu")
@@ -123,7 +122,7 @@ class DeviceCommunicatorBase:
 
         self.is_ep_communicator = "ep" in unique_name
         self.use_all2all = self.is_ep_communicator and use_ep
-        self.all2all_manager: Optional[All2AllManagerBase] = None
+        self.all2all_manager: All2AllManagerBase | None = None
 
     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         dist.all_reduce(input_, group=self.device_group)
@@ -156,10 +155,10 @@ class DeviceCommunicatorBase:
 
     def all_gatherv(
         self,
-        input_: Union[torch.Tensor, list[torch.Tensor]],
+        input_: torch.Tensor | list[torch.Tensor],
         dim: int = 0,
-        sizes: Optional[list[int]] = None,
-    ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        sizes: list[int] | None = None,
+    ) -> torch.Tensor | list[torch.Tensor]:
         raise NotImplementedError
 
     def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
@@ -196,13 +195,13 @@ class DeviceCommunicatorBase:
         return output_tensor.movedim(0, dim).contiguous()
 
     def reduce_scatterv(
-        self, input_: torch.Tensor, dim: int = -1, sizes: Optional[list[int]] = None
+        self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None
     ) -> torch.Tensor:
         raise NotImplementedError
 
     def gather(
         self, input_: torch.Tensor, dst: int = 0, dim: int = -1
-    ) -> Optional[torch.Tensor]:
+    ) -> torch.Tensor | None:
         """
         NOTE: We assume that the input tensor is on the same device across
         all the ranks.
@@ -231,7 +230,7 @@ class DeviceCommunicatorBase:
             output_tensor = None
         return output_tensor
 
-    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
+    def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
         """Sends a tensor to the destination rank in a blocking way"""
        """NOTE: `dst` is the local rank of the destination rank."""
         if dst is None:
@@ -239,7 +238,7 @@ class DeviceCommunicatorBase:
         torch.distributed.send(tensor, self.ranks[dst], self.device_group)
 
     def recv(
-        self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None
+        self, size: torch.Size, dtype: torch.dtype, src: int | None = None
     ) -> torch.Tensor:
         """Receives a tensor from the source rank."""
         """NOTE: `src` is the local rank of the source rank."""
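Note that at runtime the `|` operator on types only exists on Python 3.10+, so a rewrite like this assumes either a 3.10+ minimum version or postponed annotation evaluation; a minimal sketch of the distinction:

    from __future__ import annotations  # postponed evaluation, available since Python 3.7

    # Fine even before 3.10: the annotation is stored as a string and never evaluated.
    def f(x: int | None = None) -> int | None:
        return x

    # Still requires Python 3.10+: this `|` would be evaluated at runtime.
    # isinstance(1, int | str)  # TypeError on 3.9 and earlier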