[Kernel] DeepEP dispatch-combine kernel integration (#18434)
Signed-off-by: Varun <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
parent 01eee40536
commit fa98d77773
@@ -94,7 +94,8 @@ class FusedMoEPrepareAndFinalize(ABC):
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
               Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform any quantization (and/or) dispatching needed
        for this kernel.
@@ -113,6 +114,10 @@ class FusedMoEPrepareAndFinalize(ABC):
        Returns a tuple of:
        - quantized + dispatched a.
        - quantized + dispatched a1_scales.
        - Optional tensor as big as number of local experts that contains the
          number of tokens assigned to each local expert.
        - Optional dispatched expert topk IDs
        - Optional dispatched expert topk weight
        """
        raise NotImplementedError
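The widened return contract is easiest to see from a toy implementation. The sketch below is illustrative only: the class name NoDispatchPrepareFinalize is hypothetical, the parameter order is assumed from the prepare() call shown in forward() further down, and the implementation does no quantization and no all2all, so it leaves the three new slots as None and the caller keeps its original topk_ids / topk_weights.

from typing import Optional

import torch


class NoDispatchPrepareFinalize:
    """Hypothetical prepare() that neither quantizes nor dispatches."""

    def prepare(
        self,
        a1: torch.Tensor,
        a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
               Optional[torch.Tensor], Optional[torch.Tensor]]:
        # No quantization and no dispatch: pass the activations through and
        # leave expert_num_tokens, dispatched topk_ids and dispatched
        # topk_weights as None so the caller falls back to its own routing.
        return a1, a1_scale, None, None, None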
@@ -138,6 +143,27 @@ class FusedMoEPrepareAndFinalize(ABC):
        """
        raise NotImplementedError

    @abstractmethod
    def topk_indices_dtype(self) -> Optional[torch.dtype]:
        """
        The PrepareFinalize All2All implementations generally constrain the
        dtype of the topk_ids they support. This function returns the
        required topk indices dtype so it can be respected.
        Return None if there are no such restrictions.
        """
        raise NotImplementedError

    @abstractmethod
    def max_num_tokens_per_rank(self) -> Optional[int]:
        """
        Some PrepareFinalize All2All implementations are batched, meaning
        they can process only a set number of tokens at a time. This
        function returns that batch size, i.e. the maximum number of tokens
        the implementation can process at a time.
        Return None if there are no such restrictions.
        """
        raise NotImplementedError


class FusedMoEPermuteExpertsUnpermute(ABC):
    """
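As a hedged illustration of the two new hooks, a batched all2all backend might answer them as below. The class name and the 128-token default are invented for this example, and the int32 requirement is only an instance of the kind of constraint the docstrings describe.

from typing import Optional

import torch


class BatchedAll2AllPrepareFinalize:
    """Hypothetical batched all2all backend, used only for illustration."""

    def __init__(self, max_num_tokens: int = 128):
        self._max_num_tokens = max_num_tokens

    def topk_indices_dtype(self) -> Optional[torch.dtype]:
        # This example kernel only accepts 32-bit expert indices, so the
        # router must produce (or cast to) torch.int32 topk_ids.
        return torch.int32

    def max_num_tokens_per_rank(self) -> Optional[int]:
        # Batched implementations process a fixed number of tokens per rank
        # at a time; an unconstrained implementation would return None.
        return self._max_num_tokens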
@@ -261,6 +287,61 @@ class FusedMoEModularKernel(torch.nn.Module):
        self.prepare_finalize = prepare_finalize
        self.fused_experts = fused_experts

    def _do_fused_experts(
            self,
            a1: torch.Tensor,  # input to forward fn
            a1q: torch.Tensor,  # output of prepare fn
            w1: torch.Tensor,
            w2: torch.Tensor,
            topk_ids: torch.Tensor,
            expert_num_tokens: torch.Tensor,
            activation: str,
            global_num_experts: int,
            expert_map: Optional[torch.Tensor],
            w1_scale: Optional[torch.Tensor],
            w2_scale: Optional[torch.Tensor],
            w1_zp: Optional[torch.Tensor],
            w2_zp: Optional[torch.Tensor],
            a1q_scale: Optional[torch.Tensor],
            a2_scale: Optional[torch.Tensor]) -> torch.Tensor:

        _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)

        # Use a1 here to decipher the correct workspace datatype
        workspace13_shape, workspace2_shape, workspace_dtype = (
            self.fused_experts.workspace_shapes(a1, M, N, K, top_k,
                                                global_num_experts))

        # We can reuse the memory between cache1 and cache3 because by the time
        # we need cache3, we're done with cache1
        workspace13 = torch.zeros(workspace13_shape,
                                  device=a1.device,
                                  dtype=workspace_dtype)
        workspace2 = torch.zeros(workspace2_shape,
                                 device=a1.device,
                                 dtype=workspace_dtype)

        fused_out = self.fused_experts.apply(
            a1q,
            w1,
            w2,
            topk_ids,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            w1_zp=w1_zp,
            w2_zp=w2_zp,
            a1q_scale=a1q_scale,
            a2_scale=a2_scale,
            workspace13=workspace13,
            workspace2=workspace2,
            expert_num_tokens=expert_num_tokens,
        )

        return fused_out

    def forward(
        self,
        hidden_states: torch.Tensor,
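_do_fused_experts sizes its scratch buffers from the dispatched a1q (the local, post-dispatch problem size) while taking the workspace dtype from the original a1, through the workspace_shapes() hook whose call signature appears above. The function below is a hedged sketch of what such a hook might return; the particular shapes are illustrative and not what any real vLLM expert kernel requests.

import torch


def example_workspace_shapes(a: torch.Tensor, M: int, N: int, K: int,
                             topk: int, num_experts: int):
    # Illustrative sizing: one buffer large enough to hold both the first GEMM
    # output (M * topk, N) and the final output (M, K) -- the cache1/cache3
    # reuse mentioned in the comment above -- plus a smaller buffer for the
    # intermediate activation between the two GEMMs.
    workspace13_shape = (M * topk, max(N, K))
    workspace2_shape = (M * topk, N // 2)
    return workspace13_shape, workspace2_shape, a.dtype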
@@ -315,49 +396,48 @@ class FusedMoEModularKernel(torch.nn.Module):
        Returns:
        - torch.Tensor: The output tensor after applying the MoE layer.
        """

        a1 = hidden_states
        E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids)

        if global_num_experts == -1:
            global_num_experts = E

        output = a1 if inplace else torch.zeros_like(a1)

        workspace13_shape, workspace2_shape, workspace_dtype = (
            self.fused_experts.workspace_shapes(a1, M, N, K, top_k,
                                                global_num_experts))
        if global_num_experts == -1:
            global_num_experts = w1.size(0)

        # We can reuse the memory between cache1 and cache3 because by the time
        # we need cache3, we're done with cache1
        workspace13 = torch.zeros(workspace13_shape,
                                  device=a1.device,
                                  dtype=workspace_dtype)
        workspace2 = torch.zeros(workspace2_shape,
                                 device=a1.device,
                                 dtype=workspace_dtype)
        (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
         _expert_topk_weights) = self.prepare_finalize.prepare(
             a1, a1_scale, a2_scale, topk_weights, topk_ids,
             global_num_experts, expert_map, apply_router_weight_on_input)
        # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
        topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
        topk_weights = (topk_weights if _expert_topk_weights is None else
                        _expert_topk_weights)

        a1q, a1q_scale, expert_num_tokens = self.prepare_finalize.prepare(
            a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts,
            expert_map, apply_router_weight_on_input)

        fused_out = self.fused_experts.apply(
            a1q,
            w1,
            w2,
            topk_ids,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            w1_zp=w1_zp,
            w2_zp=w2_zp,
            a1q_scale=a1q_scale,
            a2_scale=a2_scale,
            workspace13=workspace13,
            workspace2=workspace2,
            expert_num_tokens=expert_num_tokens,
        )
        fused_out = None
        if a1q.numel() == 0:
            # This happens when none of the tokens from the all2all reach this
            # EP rank. Also, note that this is only relevant for CUDAGraph
            # incompatible all2all kernels like the DeepEP high-throughput
            # kernels. CUDAGraph compatible all2all kernels like the pplx
            # kernels and the DeepEP low-latency kernels are always batched
            # and can never run into the tensor.numel() == 0 case.
            fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
        else:
            fused_out = self._do_fused_experts(
                a1=a1,
                a1q=a1q,
                w1=w1,
                w2=w2,
                topk_ids=topk_ids,
                expert_num_tokens=expert_num_tokens,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                w1_zp=w1_zp,
                w2_zp=w2_zp,
                a1q_scale=a1q_scale,
                a2_scale=a2_scale)

        self.prepare_finalize.finalize(output, fused_out, topk_weights,
                                       topk_ids, apply_router_weight_on_input)
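Two behaviors distinguish the reworked forward() from the old inline path: it adopts the dispatched topk ids/weights when prepare() returns them, and it skips the expert computation entirely when no tokens landed on this EP rank. Below is a standalone, hedged restatement of that control flow; the helper names are invented for illustration and are not part of the vLLM API.

from typing import Callable, Optional

import torch


def select_routing(topk_ids: torch.Tensor,
                   topk_weights: torch.Tensor,
                   dispatched_ids: Optional[torch.Tensor],
                   dispatched_weights: Optional[torch.Tensor]):
    # Prefer the per-rank routing tensors gathered by prepare() (e.g. from a
    # DeepEP high-throughput dispatch); otherwise keep the router's outputs.
    ids = topk_ids if dispatched_ids is None else dispatched_ids
    weights = (topk_weights if dispatched_weights is None
               else dispatched_weights)
    return ids, weights


def run_experts_or_skip(a1q: torch.Tensor, out_dtype: torch.dtype,
                        run_experts: Callable[[torch.Tensor], torch.Tensor]):
    if a1q.numel() == 0:
        # No tokens reached this EP rank. Only CUDAGraph-incompatible,
        # non-batched all2all kernels (e.g. DeepEP high-throughput) can hit
        # this; batched kernels (pplx, DeepEP low-latency) never do.
        return torch.empty_like(a1q).to(dtype=out_dtype)
    return run_experts(a1q)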