[Kernels] Overlap shared experts with combine instead of dispatch (#24254)

Signed-off-by: Bill Nell <bnell@redhat.com>
Author: bnellnm
Date: 2025-09-18 00:10:21 -04:00
Committed by: GitHub
Parent: 027d37df38
Commit: dc2979c585
4 changed files with 203 additions and 36 deletions


@@ -209,7 +209,8 @@ class FusedMoEPrepareAndFinalize(ABC):
     def supports_async(self) -> bool:
         """
-        Indicates whether or not this class implements prepare_async.
+        Indicates whether or not this class implements prepare_async and
+        finalize_async.
         """
         return False
@@ -275,6 +276,42 @@ class FusedMoEPrepareAndFinalize(ABC):
         """
         raise NotImplementedError
 
+    def finalize_async(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: TopKWeightAndReduce,
+    ) -> Callable:
+        """
+        Perform the combine step, i.e. apply weights to and reduce the fused
+        experts output, but do not wait for results from other workers.
+        - output: The output tensor, written in place. Must be (M, K) shape.
+        - fused_expert_output: The unweighted, unreduced output of the fused
+          experts; it will have (M, topk, K) shape.
+        - topk_weights: The weights to be applied to the fused_expert_output.
+        - topk_ids: The topk ids (selected expert indices per token).
+        - apply_router_weight_on_input: When False, apply the weights to
+          fused_expert_output.
+        - weight_and_reduce_impl: An optional TopKWeightAndReduce
+          implementation.
+
+        Returns a callback that, when invoked, waits for results from other
+        workers and has the same return signature as `finalize`, e.g.
+
+            receiver = obj.finalize_async(output, ...)
+            ... output not valid yet ...
+            receiver()
+            ... output valid here ...
+
+        is equivalent to:
+
+            obj.finalize(output, ...)
+        """
+        raise NotImplementedError
+
     @property
     @abstractmethod
     def activation_format(self) -> FusedMoEActivationFormat:
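
To make the new contract concrete, here is a minimal sketch of a conforming implementation. Nothing below is part of this diff: the class and the background thread standing in for a non-blocking combine kernel are assumptions for illustration; only the method names, argument order, and tensor shapes come from the API above.

```python
import threading
from typing import Callable, Optional

import torch


class ToyAsyncFinalize:
    """Illustrative only: satisfies the finalize_async contract above."""

    def supports_async(self) -> bool:
        # True because this class implements finalize_async (a real
        # subclass would implement prepare_async as well).
        return True

    def finalize_async(
        self,
        output: torch.Tensor,               # (M, K), written in place
        fused_expert_output: torch.Tensor,  # (M, topk, K)
        topk_weights: torch.Tensor,         # (M, topk)
        topk_ids: torch.Tensor,             # (M, topk)
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: Optional[object] = None,
    ) -> Callable:
        def _combine() -> None:
            # Weight (unless already applied on the input) and reduce over
            # the topk dimension: (M, topk, K) -> (M, K).
            weighted = fused_expert_output
            if not apply_router_weight_on_input:
                weighted = weighted * topk_weights.unsqueeze(-1)
            output.copy_(weighted.sum(dim=1))

        # A real implementation would launch a non-blocking all-to-all here;
        # the thread merely simulates "work in flight".
        worker = threading.Thread(target=_combine)
        worker.start()

        # The returned callback blocks until the combine is done; `output`
        # is valid only after it returns, exactly as the docstring requires.
        return worker.join
```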
@@ -814,23 +851,20 @@ class FusedMoEModularKernel(torch.nn.Module):
         """
         a1 = hidden_states
-        output = a1 if inplace else torch.zeros_like(a1)
+        if inplace and self.shared_experts is None:
+            output = a1
+        else:
+            output = torch.zeros_like(a1)
 
         local_num_experts = w1.size(0)
         if global_num_experts == -1:
             global_num_experts = local_num_experts
 
-        shared_output: torch.Tensor
-
         if not self.prepare_finalize.supports_async():
             # We shouldn't be running an a2a kernel that doesn't
             # support async prepare/finalize
             assert not dbo_enabled()
 
-            # Run shared experts serially with dispatch.
-            if self.shared_experts is not None:
-                shared_output = self.shared_experts(a1)
-
             (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
              _expert_topk_weights) = self.prepare_finalize.prepare(
                  a1,
@@ -854,9 +888,6 @@ class FusedMoEModularKernel(torch.nn.Module):
                 self.fused_experts.quant_config,
             )
 
-            if self.shared_experts is not None:
-                shared_output = self.shared_experts(a1)
-
             # If DBO is being used, register the hook with the ubatch context
             # and call it in dbo_maybe_run_recv_hook instead of passing it to
             # the receiver.
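
The removals above are one half of the scheduling change: shared experts no longer run alongside dispatch. A compact before/after sketch of the two schedules follows; the `ops` bag of callables (`dispatch_async`, `combine_async`, etc.) is hypothetical shorthand for the prepare/finalize and expert kernels, not the vLLM API.

```python
import torch


def old_schedule(a1: torch.Tensor, ops):
    # Before this commit: shared experts overlapped with the dispatch a2a.
    hook = ops.dispatch_async(a1)    # launch all-to-all dispatch
    shared = ops.shared_experts(a1)  # overlapped with dispatch
    hook()                           # wait for routed tokens
    out = torch.empty_like(a1)
    ops.combine(out, ops.run_experts())  # blocking combine
    return shared, out


def new_schedule(a1: torch.Tensor, ops):
    # After this commit: shared experts overlapped with the combine a2a.
    ops.dispatch(a1)
    fused = ops.run_experts()
    out = torch.empty_like(a1)
    hook = ops.combine_async(out, fused)  # launch all-to-all combine
    shared = ops.shared_experts(a1)       # overlapped with combine
    hook()                                # `out` valid only after this
    return shared, out
```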
@@ -900,16 +931,42 @@ class FusedMoEModularKernel(torch.nn.Module):
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
 
-        self.prepare_finalize.finalize(
-            output,
-            fused_out,
-            topk_weights,
-            topk_ids,
-            apply_router_weight_on_input,
-            self.fused_experts.finalize_weight_and_reduce_impl(),
-        )
+        shared_output: Optional[torch.Tensor] = None
+
+        if not self.prepare_finalize.supports_async():
+            assert not dbo_enabled()
+
+            self.prepare_finalize.finalize(
+                output,
+                fused_out,
+                topk_weights,
+                topk_ids,
+                apply_router_weight_on_input,
+                self.fused_experts.finalize_weight_and_reduce_impl(),
+            )
+
+            if self.shared_experts is not None:
+                shared_output = self.shared_experts(a1)
+        else:
+            recv_hook = self.prepare_finalize.finalize_async(
+                output,
+                fused_out,
+                topk_weights,
+                topk_ids,
+                apply_router_weight_on_input,
+                self.fused_experts.finalize_weight_and_reduce_impl(),
+            )
+
+            if self.shared_experts is not None:
+                shared_output = self.shared_experts(a1)
+
+            assert recv_hook is not None
+            dbo_register_recv_hook(recv_hook)
+            dbo_yield()
+
+            if not dbo_enabled():
+                recv_hook()
 
         if self.shared_experts is None:
             return output
         else:
             assert shared_output is not None
             return shared_output, output
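
To close the loop, a usage sketch of the async path, reusing the hypothetical ToyAsyncFinalize from the earlier sketch (plus an illustrative Linear standing in for the shared experts): launch the combine, run the shared experts while it is in flight, then wait on the receiver.

```python
import torch

# Reuses ToyAsyncFinalize from the sketch after the finalize_async hunk.
M, topk, K = 4, 2, 8
a1 = torch.randn(M, K)
shared_experts = torch.nn.Linear(K, K)  # illustrative stand-in

pf = ToyAsyncFinalize()
output = torch.empty(M, K)       # must not alias a1 (see note below)
fused = torch.randn(M, topk, K)  # unweighted, unreduced expert output
weights = torch.softmax(torch.randn(M, topk), dim=-1)
ids = torch.randint(0, 16, (M, topk))

receiver = pf.finalize_async(output, fused, weights, ids, False)
shared_output = shared_experts(a1)  # overlapped with the in-flight combine
receiver()                          # output is valid only after this returns

# Mirroring forward() above: both results are returned to the caller.
result = (shared_output, output)
```

Note that `output` must not alias `a1` here, for the same reason forward() above only reuses `a1` in-place when there are no shared experts: `a1` is still read by the shared experts while the combine is in flight.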