Fix Flashinfer CUTLASS MOE Allgather (#21963)
Signed-off-by: Shu Wang <shuw@nvidia.com>
This commit is contained in:
@@ -236,7 +236,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
             input_size = input_.size()
             if sizes is not None:
                 assert len(sizes) == world_size
-                assert input_.shape[dim] == sizes[self.rank_in_group]
+                assert input_.shape[dim] == sizes[self.rank_in_group], (
+                    f"{input_.shape[dim]} != {sizes[self.rank_in_group]}")
                 output_size = (sum(sizes), ) + input_size[1:]
             else:
                 output_size = (input_size[0] * world_size, ) + input_size[1:]
Reference in New Issue
Block a user