Support Llama 4 for cutlass_moe_fp4 (#20453)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -411,13 +411,23 @@ FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
|
||||
def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
|
||||
w1_fp4: torch.Tensor, w1_blockscale: torch.Tensor,
|
||||
w1_alphas: torch.Tensor, a2_gscale: torch.Tensor,
|
||||
w2_fp4: torch.Tensor, w2_blockscale: torch.Tensor,
|
||||
w2_alphas: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor, m: int, n: int, k: int, e: int,
|
||||
device: torch.device):
|
||||
def cutlass_moe_fp4(a: torch.Tensor,
|
||||
a1_gscale: torch.Tensor,
|
||||
w1_fp4: torch.Tensor,
|
||||
w1_blockscale: torch.Tensor,
|
||||
w1_alphas: torch.Tensor,
|
||||
a2_gscale: torch.Tensor,
|
||||
w2_fp4: torch.Tensor,
|
||||
w2_blockscale: torch.Tensor,
|
||||
w2_alphas: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
device: torch.device,
|
||||
apply_router_weight_on_input: bool = False):
|
||||
"""
|
||||
MoE implementation for FP4 Inputs
|
||||
|
||||
@@ -480,6 +490,12 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
|
||||
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
# TODO: this only works for topK=1, will need to update for topK>1
|
||||
assert num_topk == 1, \
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
a.mul_(topk_weights.to(out_dtype))
|
||||
|
||||
# problem shapes should have [m, n, k]
|
||||
# Note that problem sizes are based on logical number of elements.
|
||||
ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
|
||||
@@ -517,8 +533,11 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
|
||||
del int_fp4, int_blockscale
|
||||
|
||||
c2 = ops.shuffle_rows(c2, c_map)
|
||||
out = (c2.view(m, num_topk, k) *
|
||||
topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
|
||||
if not apply_router_weight_on_input:
|
||||
out = (c2.view(m, num_topk, k) *
|
||||
topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1)
|
||||
else:
|
||||
out = c2.view(m, num_topk, k).sum(dim=1)
|
||||
return out.to(dtype=out_dtype)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user