[Kernels] Overlap shared experts with send/recv (#23273)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from math import prod
|
||||
from typing import Optional, final
|
||||
from typing import Callable, Optional, Union, final
|
||||
|
||||
import torch
|
||||
|
||||
@@ -141,6 +141,29 @@ class TopKWeightAndReduce(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
#
|
||||
# PrepareResultType is a tuple of:
|
||||
# - quantized + dispatched a.
|
||||
# - quantized + dispatched a1_scales.
|
||||
# - Optional ExpertTokensMetadata containing gpu/cpu tensors
|
||||
# as big as the number of local experts with the information about the
|
||||
# number of tokens assigned to each local expert.
|
||||
# - Optional dispatched expert topk IDs
|
||||
# - Optional dispatched expert topk weights
|
||||
#
|
||||
# See `prepare` method below.
|
||||
#
|
||||
PrepareResultType = tuple[
|
||||
torch.Tensor,
|
||||
Optional[torch.Tensor],
|
||||
Optional[ExpertTokensMetadata],
|
||||
Optional[torch.Tensor],
|
||||
Optional[torch.Tensor],
|
||||
]
|
||||
|
||||
# ReceiverType: a zero-argument callable that returns a PrepareResultType.
ReceiverType = Callable[[], PrepareResultType]
|
||||
|
||||
|
||||
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
|
||||
class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
@@ -160,16 +183,9 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[
|
||||
torch.Tensor,
|
||||
Optional[torch.Tensor],
|
||||
Optional[ExpertTokensMetadata],
|
||||
Optional[torch.Tensor],
|
||||
Optional[torch.Tensor],
|
||||
]:
|
||||
) -> PrepareResultType:
|
||||
"""
|
||||
Perform any quantization (and/or) dispatching needed
|
||||
for this kernel.
|
||||
Perform any quantization (and/or) dispatching needed for this kernel.
|
||||
- a1: The (unquantized) input to the MoE layer.
|
||||
- a1_scale: Optional scales for a1
|
||||
- a2_scale: Optional scales for the second MoE gemm. Required to make
|
||||
@@ -193,6 +209,51 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def supports_async(self) -> bool:
    """
    Report whether this class provides an implementation of prepare_async.

    This default implementation always answers False.
    """
    return False
|
||||
|
||||
def prepare_async(
    self,
    a1: torch.Tensor,
    a1_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    expert_map: Optional[torch.Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
) -> ReceiverType:
    """
    Start any quantization (and/or) dispatching this kernel needs, without
    waiting for the results from other workers.

    - a1: The (unquantized) input to the MoE layer.
    - a1_scale: Optional scales for a1.
    - a2_scale: Optional scales for the second MoE gemm. Required so that
      quantization is consistent for both gemms.
    - topk_weights: The topk weights.
    - topk_ids: The topk ids.
    - num_experts: The total number of experts in the global expert space.
    - expert_map: A tensor mapping expert indices from the global expert
      space to the local expert space of the expert parallel shard.
    - apply_router_weight_on_input: When True, apply the weights to the
      activations, before quantization + dispatching.

    Returns a callback. Invoking the callback waits for results from other
    workers and yields the same tuple that `prepare` returns, i.e.

        receiver = obj.prepare_async(...)
        a, a_scales, expert_meta, topk_ids, topk_weights = receiver()

    is equivalent to:

        a, a_scales, expert_meta, topk_ids, topk_weights = obj.prepare(...)

    This base implementation is not usable and raises NotImplementedError.
    """
    raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def finalize(
|
||||
self,
|
||||
@@ -453,10 +514,12 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
fused_experts: FusedMoEPermuteExpertsUnpermute,
|
||||
shared_experts: Optional[torch.nn.Module] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.prepare_finalize = prepare_finalize
|
||||
self.fused_experts = fused_experts
|
||||
self.shared_experts = shared_experts
|
||||
assert prepare_finalize.activation_format == \
|
||||
fused_experts.activation_formats[0], (
|
||||
f"{prepare_finalize.__class__.__name__}."
|
||||
@@ -692,7 +755,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets
|
||||
of weights, w1 and w2, and top-k gating mechanism.
|
||||
@@ -736,18 +799,46 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = local_num_experts
|
||||
|
||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||
_expert_topk_weights) = self.prepare_finalize.prepare(
|
||||
a1,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.quant_config,
|
||||
)
|
||||
shared_output: torch.Tensor
|
||||
|
||||
if (not self.prepare_finalize.supports_async()
|
||||
or self.shared_experts is None):
|
||||
|
||||
# Run shared experts serially with dispatch.
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(a1)
|
||||
|
||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||
_expert_topk_weights) = self.prepare_finalize.prepare(
|
||||
a1,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.quant_config,
|
||||
)
|
||||
else:
|
||||
# Overlap shared expert compute with all2all dispatch.
|
||||
receiver = self.prepare_finalize.prepare_async(
|
||||
a1,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.quant_config,
|
||||
)
|
||||
|
||||
assert self.shared_experts is not None
|
||||
shared_output = self.shared_experts(a1)
|
||||
|
||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||
_expert_topk_weights) = receiver()
|
||||
|
||||
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
|
||||
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
|
||||
@@ -795,4 +886,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
self.fused_experts.finalize_weight_and_reduce_impl(),
|
||||
)
|
||||
|
||||
return output
|
||||
if self.shared_experts is None:
|
||||
return output
|
||||
else:
|
||||
return shared_output, output
|
||||
|
||||
Reference in New Issue
Block a user