Add pynccl all-gatherv and reducescatterv (#20154)

Signed-off-by: Trevor Morris <tmorris@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Trevor Morris
2025-07-11 18:59:23 -07:00
committed by GitHub
parent fc0f41d10a
commit a8593237c0
6 changed files with 284 additions and 2 deletions

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from typing import Optional
from typing import Optional, Union
from weakref import WeakValueDictionary
import torch
@@ -138,6 +138,14 @@ class DeviceCommunicatorBase:
input_size[dim + 1:])
return output_tensor
def all_gatherv(
    self,
    input_: Union[torch.Tensor, list[torch.Tensor]],
    dim: int = 0,
    sizes: Optional[list[int]] = None,
) -> Union[torch.Tensor, list[torch.Tensor]]:
    """Placeholder for the all-gatherv collective.

    This implementation performs no communication: it raises
    unconditionally and never inspects its arguments.

    Args:
        input_: A single tensor or a list of tensors.
        dim: Dimension argument; defaults to 0.
        sizes: Optional list of ints.
            NOTE(review): presumably the per-rank extents along ``dim``
            for ranks contributing different amounts -- confirm against
            the concrete communicator that implements this method.

    Returns:
        Nothing here; the annotation describes what an implementation is
        expected to return (a tensor or a list of tensors).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
def reduce_scatter(self,
input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
@@ -172,6 +180,12 @@ class DeviceCommunicatorBase:
# Reshape before returning
return output_tensor.movedim(0, dim).contiguous()
def reduce_scatterv(
    self,
    input_: torch.Tensor,
    dim: int = -1,
    sizes: Optional[list[int]] = None,
) -> torch.Tensor:
    """Placeholder for the reduce-scatterv collective.

    This implementation performs no communication: it raises
    unconditionally and never inspects its arguments.

    Args:
        input_: Input tensor.
        dim: Dimension argument; defaults to -1.
        sizes: Optional list of ints.
            NOTE(review): presumably the per-rank output extents along
            ``dim`` for an uneven scatter -- confirm against the concrete
            communicator that implements this method.

    Returns:
        Nothing here; the annotation describes what an implementation is
        expected to return (a tensor).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
def gather(self,
input_: torch.Tensor,
dst: int = 0,