Fix Flashinfer CUTLASS MOE Allgather (#21963)

Signed-off-by: Shu Wang <shuw@nvidia.com>
This commit is contained in:
Shu Wang
2025-08-07 21:18:25 -05:00
committed by GitHub
parent a3b9c17b56
commit b2c8ce57c6
4 changed files with 71 additions and 27 deletions

View File

@@ -236,7 +236,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
input_size = input_.size()
if sizes is not None:
assert len(sizes) == world_size
- assert input_.shape[dim] == sizes[self.rank_in_group]
+ assert input_.shape[dim] == sizes[self.rank_in_group], (
+ f"{input_.shape[dim]} != {sizes[self.rank_in_group]}")
output_size = (sum(sizes), ) + input_size[1:]
else:
output_size = (input_size[0] * world_size, ) + input_size[1:]