[Kernel] Add expert_map support to Cutlass FP8 MOE (#16861)
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com> Co-authored-by: varun sundar rabindranath <vsundarr@redhat.com>
This commit is contained in:
committed by
GitHub
parent
c9acbf1141
commit
7b8a2ab76f
@@ -15,7 +15,7 @@ def cutlass_moe_fp8(
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_ids_: torch.Tensor,
|
||||
ab_strides1: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
@@ -23,6 +23,7 @@ def cutlass_moe_fp8(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
out_dtype: torch.dtype = torch.half,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -57,12 +58,19 @@ def cutlass_moe_fp8(
|
||||
quantize the intermediate result between the gemms.
|
||||
Shape: scalar or [M]
|
||||
- out_dtype (torch.Tensor): The output tensor type.
|
||||
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
|
||||
every Rank is responsible for a subset of experts. expert_map is a
|
||||
mapping from global expert-id to local expert-id. When expert_map[i]
|
||||
is -1, it means that this Rank is not responsible for global
|
||||
expert-id i.
|
||||
- apply_router_weight_on_input (bool): When true, the topk weights are
|
||||
applied directly on the inputs. This is only applicable when topk is 1.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
|
||||
"""
|
||||
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
|
||||
assert w1_q.dtype == torch.float8_e4m3fn
|
||||
assert w2_q.dtype == torch.float8_e4m3fn
|
||||
assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
|
||||
@@ -96,7 +104,13 @@ def cutlass_moe_fp8(
|
||||
k = w1_q.size(1)
|
||||
n = w2_q.size(1)
|
||||
|
||||
topk = topk_ids.size(1)
|
||||
local_topk_ids = topk_ids_
|
||||
if expert_map is not None:
|
||||
"Translate info from expert_map to topk_ids"
|
||||
local_topk_ids = torch.where(expert_map[topk_ids_] != -1,
|
||||
expert_map[topk_ids_], -1)
|
||||
|
||||
topk = local_topk_ids.size(1)
|
||||
|
||||
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
|
||||
a2_scale.numel() != 1 if a2_scale is not None else False)
|
||||
@@ -120,10 +134,23 @@ def cutlass_moe_fp8(
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
a_map_initializer = torch.empty
|
||||
c2_initializer = torch.empty
|
||||
if expert_map is not None:
|
||||
# With expert_map each Rank processes only a subset of experts. As
|
||||
# a result not all of a_map and c2 tensors are filled. We fill it
|
||||
# zeros for correctness.
|
||||
a_map_initializer = torch.zeros
|
||||
c2_initializer = torch.zeros
|
||||
|
||||
ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
|
||||
a_map = a_map_initializer((local_topk_ids.numel()),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
c_map = torch.empty((local_topk_ids.numel()),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, problem_sizes1,
|
||||
problem_sizes2, a_map, c_map, num_experts, n,
|
||||
k)
|
||||
|
||||
@@ -131,7 +158,7 @@ def cutlass_moe_fp8(
|
||||
rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale
|
||||
|
||||
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
|
||||
c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
|
||||
c2 = c2_initializer((m * topk, k), device=device, dtype=out_dtype)
|
||||
|
||||
ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
|
||||
expert_offsets[:-1], problem_sizes1, ab_strides1,
|
||||
|
||||
Reference in New Issue
Block a user