Add pynccl all-gatherv and reducescatterv (#20154)
Signed-off-by: Trevor Morris <tmorris@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import torch
|
||||
@@ -138,6 +138,14 @@ class DeviceCommunicatorBase:
|
||||
input_size[dim + 1:])
|
||||
return output_tensor
|
||||
|
||||
def all_gatherv(
    self,
    input_: Union[torch.Tensor, list[torch.Tensor]],
    dim: int = 0,
    sizes: Optional[list[int]] = None
) -> Union[torch.Tensor, list[torch.Tensor]]:
    """Variable-size all-gather hook; concrete communicators must override.

    Accepts either a single tensor or a list of tensors and returns the
    gathered result in the same form.

    Args:
        input_: Tensor (or list of tensors) local to this rank.
        dim: Dimension along which to gather (default 0).
        sizes: Optional per-rank sizes along ``dim``; presumably one entry
            per rank so ranks may contribute unequal shards — TODO confirm
            against a concrete implementation.

    Raises:
        NotImplementedError: Always, in this base class.
    """
    raise NotImplementedError
|
||||
|
||||
def reduce_scatter(self,
|
||||
input_: torch.Tensor,
|
||||
dim: int = -1) -> torch.Tensor:
|
||||
@@ -172,6 +180,12 @@ class DeviceCommunicatorBase:
|
||||
# Reshape before returning
|
||||
return output_tensor.movedim(0, dim).contiguous()
|
||||
|
||||
def reduce_scatterv(self,
                    input_: torch.Tensor,
                    dim: int = -1,
                    sizes: Optional[list[int]] = None) -> torch.Tensor:
    """Variable-size reduce-scatter hook; concrete communicators must override.

    Args:
        input_: Tensor local to this rank to be reduced and scattered.
        dim: Dimension along which to scatter (default -1).
        sizes: Optional per-rank sizes along ``dim``; presumably one entry
            per rank so shards may be unequal — TODO confirm against a
            concrete implementation.

    Raises:
        NotImplementedError: Always, in this base class.
    """
    raise NotImplementedError
|
||||
|
||||
def gather(self,
|
||||
input_: torch.Tensor,
|
||||
dst: int = 0,
|
||||
|
||||
Reference in New Issue
Block a user