[Kernel] Add expert_map support to Cutlass FP8 MOE (#16861)

Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Co-authored-by: varun sundar rabindranath <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath
2025-04-21 23:44:32 -04:00
committed by GitHub
parent c9acbf1141
commit 7b8a2ab76f
5 changed files with 331 additions and 171 deletions

View File

@@ -15,7 +15,7 @@ def cutlass_moe_fp8(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
topk_ids_: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
@@ -23,6 +23,7 @@ def cutlass_moe_fp8(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
out_dtype: torch.dtype = torch.half,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
"""
@@ -57,12 +58,19 @@ def cutlass_moe_fp8(
quantize the intermediate result between the gemms.
Shape: scalar or [M]
- out_dtype (torch.Tensor): The output tensor type.
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
every Rank is responsible for a subset of experts. expert_map is a
mapping from global expert-id to local expert-id. When expert_map[i]
is -1, it means that this Rank is not responsible for global
expert-id i.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is 1.
Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
"""
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
assert w1_q.dtype == torch.float8_e4m3fn
assert w2_q.dtype == torch.float8_e4m3fn
assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
@@ -96,7 +104,13 @@ def cutlass_moe_fp8(
k = w1_q.size(1)
n = w2_q.size(1)
topk = topk_ids.size(1)
local_topk_ids = topk_ids_
if expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(expert_map[topk_ids_] != -1,
expert_map[topk_ids_], -1)
topk = local_topk_ids.size(1)
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
@@ -120,10 +134,23 @@ def cutlass_moe_fp8(
dtype=torch.int32,
device=device)
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
a_map_initializer = torch.empty
c2_initializer = torch.empty
if expert_map is not None:
# With expert_map each Rank processes only a subset of experts. As
# a result not all of a_map and c2 tensors are filled. We fill it
# zeros for correctness.
a_map_initializer = torch.zeros
c2_initializer = torch.zeros
ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
a_map = a_map_initializer((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
c_map = torch.empty((local_topk_ids.numel()),
dtype=torch.int32,
device=device)
ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, a_map, c_map, num_experts, n,
k)
@@ -131,7 +158,7 @@ def cutlass_moe_fp8(
rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
c2 = c2_initializer((m * topk, k), device=device, dtype=out_dtype)
ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
expert_offsets[:-1], problem_sizes1, ab_strides1,