[Kernels] Overlap shared experts with send/recv (#23273)

Signed-off-by: Bill Nell <bnell@redhat.com>
Author: bnellnm
Date: 2025-09-03 12:35:18 -04:00
Committed by: GitHub
Parent: fa4311d85f
Commit: e9b92dcd89
32 changed files with 885 additions and 227 deletions


@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import Enum
 from math import prod
-from typing import Optional, final
+from typing import Callable, Optional, Union, final

 import torch
@@ -141,6 +141,29 @@ class TopKWeightAndReduce(ABC):
         raise NotImplementedError


+#
+# PrepareResultType is a tuple of:
+# - quantized + dispatched a.
+# - quantized + dispatched a1_scales.
+# - Optional ExpertTokensMetadata containing gpu/cpu tensors
+#   as big as the number of local experts with the information about the
+#   number of tokens assigned to each local expert.
+# - Optional dispatched expert topk IDs
+# - Optional dispatched expert topk weight
+#
+# See `prepare` method below.
+#
+PrepareResultType = tuple[
+    torch.Tensor,
+    Optional[torch.Tensor],
+    Optional[ExpertTokensMetadata],
+    Optional[torch.Tensor],
+    Optional[torch.Tensor],
+]
+
+ReceiverType = Callable[[], PrepareResultType]
+
+
 # TODO: pass FusedMoEParallelConfig in as ctor parameter?
 class FusedMoEPrepareAndFinalize(ABC):
     """
@@ -160,16 +183,9 @@ class FusedMoEPrepareAndFinalize(ABC):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> tuple[
-        torch.Tensor,
-        Optional[torch.Tensor],
-        Optional[ExpertTokensMetadata],
-        Optional[torch.Tensor],
-        Optional[torch.Tensor],
-    ]:
+    ) -> PrepareResultType:
         """
-        Perform any quantization (and/or) dispatching needed
-        for this kernel.
+        Perform any quantization (and/or) dispatching needed for this kernel.
         - a1: The (unquantized) input to the MoE layer.
         - a1_scale: Optional scales for a1
         - a2_scale: Optional scales for the second MoE gemm. Required to make
@@ -193,6 +209,51 @@ class FusedMoEPrepareAndFinalize(ABC):
         """
         raise NotImplementedError

+    def supports_async(self) -> bool:
+        """
+        Indicates whether or not this class implements prepare_async.
+        """
+        return False
+
+    def prepare_async(
+        self,
+        a1: torch.Tensor,
+        a1_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        apply_router_weight_on_input: bool,
+        quant_config: FusedMoEQuantConfig,
+    ) -> ReceiverType:
+        """
+        Perform any quantization (and/or) dispatching needed for this kernel
+        but do not wait for results from other workers.
+        - a1: The (unquantized) input to the MoE layer.
+        - a1_scale: Optional scales for a1
+        - a2_scale: Optional scales for the second MoE gemm. Required to make
+          sure the quantization is consistent for both gemms.
+        - topk_ids: The topk ids.
+        - topk_weights: The topk weights.
+        - num_experts: The total number of experts in the global expert space.
+        - expert_map: A tensor mapping expert indices from the global expert
+          space to the local expert space of the expert parallel shard.
+        - apply_router_weight_on_input: When True, apply the weights to the
+          activations, before quantization + dispatching.
+
+        Returns a callback that when invoked waits for results from other
+        workers and has the same return signature as `prepare`, e.g.
+
+            receiver = obj.prepare_async(...)
+            a, a_scales, expert_meta, topk_ids, topk_weights = receiver()
+
+        is equivalent to:
+
+            a, a_scales, expert_meta, topk_ids, topk_weights = obj.prepare(...)
+        """
+        raise NotImplementedError
+
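A minimal sketch of how a subclass might satisfy this contract, assuming a hypothetical communication backend that exposes separate issue/wait calls (the backend, its method names, and self._quantize are assumptions, not part of the diff; the other abstract methods of FusedMoEPrepareAndFinalize are elided):

    # Hypothetical subclass sketch; self._backend, dispatch_issue/dispatch_wait
    # and self._quantize are illustrative stand-ins for a real all2all backend.
    class AsyncPrepareFinalizeSketch(FusedMoEPrepareAndFinalize):
        def supports_async(self) -> bool:
            return True

        def prepare_async(self, a1, a1_scale, a2_scale, topk_weights, topk_ids,
                          num_experts, expert_map, apply_router_weight_on_input,
                          quant_config) -> ReceiverType:
            # Quantize locally, then kick off the send/recv without waiting.
            a1q, a1q_scale = self._quantize(a1, a1_scale, a2_scale, quant_config)
            handle = self._backend.dispatch_issue(a1q, a1q_scale, topk_ids)

            def receiver() -> PrepareResultType:
                # Blocks until every rank has delivered its tokens.
                a1q_recv, scale_recv, meta = self._backend.dispatch_wait(handle)
                return a1q_recv, scale_recv, meta, None, None

            return receiver

        def prepare(self, *args, **kwargs) -> PrepareResultType:
            # The blocking path can simply issue and wait back-to-back.
            return self.prepare_async(*args, **kwargs)()

Keeping the dispatch state inside the closure lets the caller run unrelated work (here: the shared experts) between prepare_async() and receiver().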
     @abstractmethod
     def finalize(
         self,
@@ -453,10 +514,12 @@ class FusedMoEModularKernel(torch.nn.Module):
         self,
         prepare_finalize: FusedMoEPrepareAndFinalize,
         fused_experts: FusedMoEPermuteExpertsUnpermute,
+        shared_experts: Optional[torch.nn.Module] = None,
     ):
         super().__init__()
         self.prepare_finalize = prepare_finalize
         self.fused_experts = fused_experts
+        self.shared_experts = shared_experts
         assert prepare_finalize.activation_format == \
             fused_experts.activation_formats[0], (
                 f"{prepare_finalize.__class__.__name__}."
@@ -692,7 +755,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         """
         This function computes a Mixture of Experts (MoE) layer using two sets
         of weights, w1 and w2, and top-k gating mechanism.
@@ -736,18 +799,46 @@ class FusedMoEModularKernel(torch.nn.Module):
         if global_num_experts == -1:
             global_num_experts = local_num_experts

-        (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
-         _expert_topk_weights) = self.prepare_finalize.prepare(
-             a1,
-             a1_scale,
-             a2_scale,
-             topk_weights,
-             topk_ids,
-             global_num_experts,
-             expert_map,
-             apply_router_weight_on_input,
-             self.fused_experts.quant_config,
-         )
+        shared_output: torch.Tensor
+
+        if (not self.prepare_finalize.supports_async()
+                or self.shared_experts is None):
+            # Run shared experts serially with dispatch.
+            if self.shared_experts is not None:
+                shared_output = self.shared_experts(a1)
+
+            (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
+             _expert_topk_weights) = self.prepare_finalize.prepare(
+                 a1,
+                 a1_scale,
+                 a2_scale,
+                 topk_weights,
+                 topk_ids,
+                 global_num_experts,
+                 expert_map,
+                 apply_router_weight_on_input,
+                 self.fused_experts.quant_config,
+             )
+        else:
+            # Overlap shared expert compute with all2all dispatch.
+            receiver = self.prepare_finalize.prepare_async(
+                a1,
+                a1_scale,
+                a2_scale,
+                topk_weights,
+                topk_ids,
+                global_num_experts,
+                expert_map,
+                apply_router_weight_on_input,
+                self.fused_experts.quant_config,
+            )
+
+            assert self.shared_experts is not None
+            shared_output = self.shared_experts(a1)
+
+            (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
+             _expert_topk_weights) = receiver()

         # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
         topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
@@ -795,4 +886,7 @@ class FusedMoEModularKernel(torch.nn.Module):
             self.fused_experts.finalize_weight_and_reduce_impl(),
         )

-        return output
+        if self.shared_experts is None:
+            return output
+        else:
+            return shared_output, output
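Because forward() now returns either a single tensor or a (shared_output, fused_output) pair, callers that pass shared_experts have to unpack the result. A self-contained sketch of one way to normalize the two shapes (the elementwise add is an assumed combination step, not something this diff prescribes):

    from typing import Union

    import torch

    # Illustration only: collapse both possible return shapes of
    # FusedMoEModularKernel.forward() into a single tensor.
    def combine_moe_output(
        out: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
    ) -> torch.Tensor:
        if isinstance(out, tuple):
            shared_output, fused_output = out
            return shared_output + fused_output  # assumed combination; layer-specific
        return out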