[Kernels] Overlap shared experts with combine instead of dispatch (#24254)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -209,7 +209,8 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
|
||||
def supports_async(self) -> bool:
|
||||
"""
|
||||
Indicates whether or not this class implements prepare_async.
|
||||
Indicates whether or not this class implements prepare_async and
|
||||
finalize_async.
|
||||
"""
|
||||
return False
|
||||
|
||||
@@ -275,6 +276,42 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def finalize_async(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: TopKWeightAndReduce,
|
||||
) -> Callable:
|
||||
"""
|
||||
Perform any combine plus apply weights and perform a reduction on the
|
||||
fused experts output but do not wait for results from other workers.
|
||||
- output: The output tensor, written in place. Must be (M, K) shape.
|
||||
- fused_expert_output: The unweighted, unreduced output of the fused
|
||||
experts; it will have (M, topk, K) shape.
|
||||
- topk_weights: The weights to be applied to the fused_expert_output.
|
||||
- topk_ids: The topk_ids.
|
||||
- apply_router_weight_on_input: When False, apply the weights to
|
||||
fused_expert_output.
|
||||
- weight_and_reduce_impl: An optional TopKWeightAndReduce
|
||||
implementation.
|
||||
|
||||
Returns a callback that, when invoked, waits for results from other
|
||||
workers and has the same return signature as `finalize`, e.g.
|
||||
|
||||
receiver = obj.finalize_async(output, ...)
|
||||
... output not valid yet ...
|
||||
receiver()
|
||||
... output valid here ...
|
||||
|
||||
is equivalent to:
|
||||
|
||||
obj.finalize(output, ...)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def activation_format(self) -> FusedMoEActivationFormat:
|
||||
@@ -814,23 +851,20 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
"""
|
||||
|
||||
a1 = hidden_states
|
||||
output = a1 if inplace else torch.zeros_like(a1)
|
||||
if inplace and self.shared_experts is None:
|
||||
output = a1
|
||||
else:
|
||||
output = torch.zeros_like(a1)
|
||||
|
||||
local_num_experts = w1.size(0)
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = local_num_experts
|
||||
|
||||
shared_output: torch.Tensor
|
||||
|
||||
if not self.prepare_finalize.supports_async():
|
||||
# We shouldn't be running an a2a kernel that doesn't
|
||||
# support async prepare/finalize
|
||||
assert not dbo_enabled()
|
||||
|
||||
# Run shared experts serially with dispatch.
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(a1)
|
||||
|
||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||
_expert_topk_weights) = self.prepare_finalize.prepare(
|
||||
a1,
|
||||
@@ -854,9 +888,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
self.fused_experts.quant_config,
|
||||
)
|
||||
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(a1)
|
||||
|
||||
# If DBO is being used, register the hook with the ubatch context
|
||||
# and call it in dbo_maybe_run_recv_hook instead of passing it to
|
||||
# the receiver.
|
||||
@@ -900,16 +931,42 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
self.prepare_finalize.finalize(
|
||||
output,
|
||||
fused_out,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.finalize_weight_and_reduce_impl(),
|
||||
)
|
||||
shared_output: Optional[torch.Tensor] = None
|
||||
|
||||
if not self.prepare_finalize.supports_async():
|
||||
assert not dbo_enabled()
|
||||
|
||||
self.prepare_finalize.finalize(
|
||||
output,
|
||||
fused_out,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.finalize_weight_and_reduce_impl(),
|
||||
)
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(a1)
|
||||
else:
|
||||
recv_hook = self.prepare_finalize.finalize_async(
|
||||
output,
|
||||
fused_out,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.finalize_weight_and_reduce_impl(),
|
||||
)
|
||||
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(a1)
|
||||
|
||||
assert recv_hook is not None
|
||||
dbo_register_recv_hook(recv_hook)
|
||||
dbo_yield()
|
||||
if not dbo_enabled():
|
||||
recv_hook()
|
||||
|
||||
if self.shared_experts is None:
|
||||
return output
|
||||
else:
|
||||
assert shared_output is not None
|
||||
return shared_output, output
|
||||
|
||||
Reference in New Issue
Block a user