Support Llama 4 for cutlass_moe_fp4 (#20453)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2025-07-10 04:53:38 +09:00
committed by GitHub
parent e59ba9e142
commit 31b96d1c64
3 changed files with 80 additions and 74 deletions

View File

@@ -411,13 +411,23 @@ FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
w1_fp4: torch.Tensor, w1_blockscale: torch.Tensor,
w1_alphas: torch.Tensor, a2_gscale: torch.Tensor,
w2_fp4: torch.Tensor, w2_blockscale: torch.Tensor,
w2_alphas: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, m: int, n: int, k: int, e: int,
device: torch.device):
def cutlass_moe_fp4(a: torch.Tensor,
a1_gscale: torch.Tensor,
w1_fp4: torch.Tensor,
w1_blockscale: torch.Tensor,
w1_alphas: torch.Tensor,
a2_gscale: torch.Tensor,
w2_fp4: torch.Tensor,
w2_blockscale: torch.Tensor,
w2_alphas: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
m: int,
n: int,
k: int,
e: int,
device: torch.device,
apply_router_weight_on_input: bool = False):
"""
MoE implementation for FP4 Inputs
@@ -480,6 +490,12 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
if apply_router_weight_on_input:
# TODO: this only works for topK=1, will need to update for topK>1
assert num_topk == 1, \
"apply_router_weight_on_input is only implemented for topk=1"
a.mul_(topk_weights.to(out_dtype))
# problem shapes should have [m, n, k]
# Note that problem sizes are based on logical number of elements.
ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
@@ -517,8 +533,11 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
del int_fp4, int_blockscale
c2 = ops.shuffle_rows(c2, c_map)
out = (c2.view(m, num_topk, k) *
topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
if not apply_router_weight_on_input:
out = (c2.view(m, num_topk, k) *
topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1)
else:
out = c2.view(m, num_topk, k).sum(dim=1)
return out.to(dtype=out_dtype)