[Perf] Do FP4 quant before All gather on flashinfer trtllmgen MOE (#30014)
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
This commit is contained in:
@@ -64,7 +64,12 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if extra_tensors is not None:
|
||||
raise NotImplementedError(
|
||||
"extra_tensors is not supported for NaiveAll2AllManager"
|
||||
)
|
||||
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
assert dp_metadata is not None
|
||||
@@ -76,6 +81,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
||||
router_logits = self.naive_multicast(
|
||||
router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel
|
||||
)
|
||||
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(
|
||||
@@ -113,7 +119,11 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Gather hidden_states and router_logits from all dp ranks.
|
||||
"""
|
||||
@@ -121,15 +131,22 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
assert dp_metadata is not None
|
||||
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
assert sizes is not None
|
||||
|
||||
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
|
||||
hidden_states, router_logits = dist_group.all_gatherv(
|
||||
[hidden_states, router_logits],
|
||||
|
||||
tensors_to_gather = [hidden_states, router_logits]
|
||||
if extra_tensors is not None:
|
||||
tensors_to_gather.extend(extra_tensors)
|
||||
|
||||
gathered_tensors = dist_group.all_gatherv(
|
||||
tensors_to_gather,
|
||||
dim=0,
|
||||
sizes=sizes,
|
||||
)
|
||||
return hidden_states, router_logits
|
||||
|
||||
if extra_tensors is not None:
|
||||
return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
|
||||
return gathered_tensors[0], gathered_tensors[1]
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
@@ -204,6 +221,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -251,6 +269,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
Reference in New Issue
Block a user