[Feature] support sequence parallelism using compilation pass (#16155)

Signed-off-by: cascade812 <cascade812@outlook.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Author: cascade
Date: 2025-04-27 06:29:35 -07:00
Committed by: GitHub
Parent: ed7a29d9f8
Commit: 690fe019f0

21 changed files with 1072 additions and 44 deletions

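This change adds a reduce_scatter primitive to the device communicators so the new compilation pass can rewrite the tensor-parallel all_reduce into a reduce_scatter + all_gather pair, keeping activations sharded along the sequence dimension between the two collectives. A minimal single-process sketch of that decomposition (variable names and shapes are illustrative, not vLLM's API):

    import torch

    # Simulate the all_reduce -> reduce_scatter + all_gather rewrite on one
    # process. Each "rank" holds a partial result of shape (seq_len, hidden).
    world_size, seq_len, hidden = 4, 8, 16
    partials = [torch.randn(seq_len, hidden) for _ in range(world_size)]

    # Baseline: all_reduce leaves every rank with the full summed tensor.
    all_reduced = torch.stack(partials).sum(dim=0)

    # Sequence parallel: reduce_scatter hands each rank one sequence shard
    # of the sum; per-token ops (e.g. RMSNorm) run on the shard, and
    # all_gather reassembles the full sequence afterwards.
    shards = all_reduced.chunk(world_size, dim=0)   # the reduce_scatter step
    reassembled = torch.cat(shards, dim=0)          # the all_gather step

    assert torch.equal(all_reduced, reassembled)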

@@ -61,6 +61,40 @@ class DeviceCommunicatorBase:
                                              input_size[dim + 1:])
        return output_tensor

    def reduce_scatter(self,
                       input_: torch.Tensor,
                       dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()

        # Note: This will produce an incorrect answer if we don't make
        # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
        input_tensor = input_.movedim(0, dim).contiguous()

        assert input_tensor.shape[0] % world_size == 0
        chunk_size = input_tensor.shape[0] // world_size
        output_shape = (chunk_size, ) + input_tensor.shape[1:]
        output_tensor = torch.empty(output_shape,
                                    dtype=input_tensor.dtype,
                                    device=input_tensor.device)

        # Perform reduce-scatter operation
        torch.distributed.reduce_scatter_tensor(output_tensor,
                                                input_tensor,
                                                group=self.device_group)

        # Move the scattered dimension back into place before returning.
        return output_tensor.movedim(0, dim).contiguous()

    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,

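Both implementations rely on the same trick: reduce_scatter_tensor always scatters along dim 0, so the target dimension is moved to the front first and moved back afterwards. For the default dim=-1 on a 2-D activation, movedim(0, dim) simply swaps the axes. A quick shape walk-through under assumed sizes (world_size=4, an 8x64 per-rank input):

    import torch

    # Shape bookkeeping for reduce_scatter(input_, dim=-1) with world_size=4.
    world_size = 4
    x = torch.randn(8, 64)       # per-rank input: (seq, hidden)
    dim = x.dim() - 1            # dim=-1 converted to positive, i.e. 1

    moved = x.movedim(0, dim).contiguous()               # (64, 8): target dim leads
    chunk_size = moved.shape[0] // world_size            # 16 rows per rank
    out = torch.empty((chunk_size,) + moved.shape[1:])   # (16, 8)

    # After the collective fills `out`, moving dim 0 back yields each rank's
    # shard of the original layout.
    print(out.movedim(0, dim).shape)                     # torch.Size([8, 16])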

@@ -70,6 +70,31 @@ class CudaCommunicator(DeviceCommunicatorBase):
        torch.distributed.all_reduce(out, group=self.device_group)
        return out

    def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
        world_size = self.world_size
        pynccl_comm = self.pynccl_comm
        assert pynccl_comm is not None
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()

        # Note: This will produce an incorrect answer if we don't make
        # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
        input_tensor = input_.movedim(0, dim).contiguous()

        assert input_tensor.shape[0] % world_size == 0
        chunk_size = input_tensor.shape[0] // world_size
        output_shape = (chunk_size, ) + input_tensor.shape[1:]
        output = torch.empty(output_shape,
                             dtype=input_tensor.dtype,
                             device=input_tensor.device)

        # Reduce-scatter the contiguous, moved tensor (not the raw input_),
        # matching the base implementation above.
        pynccl_comm.reduce_scatter(output, input_tensor)

        # Move the scattered dimension back into place before returning.
        return output.movedim(0, dim).contiguous()
    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Sends a tensor to the destination rank in a non-blocking way.

        NOTE: `dst` is the local rank of the destination rank.
        """