Upstream Llama4 Support to Main (#16113)
Signed-off-by: Aston Zhang <22279212+astonzhang@users.noreply.github.com> Signed-off-by: Chris Thi <chris.c.thi@gmail.com> Signed-off-by: drisspg <drisspguessous@gmail.com> Signed-off-by: Jon Swenson <jmswen@gmail.com> Signed-off-by: Keyun Tong <tongkeyun@gmail.com> Signed-off-by: Lu Fang <fanglu@meta.com> Signed-off-by: Xiaodong Wang <xdwang@meta.com> Signed-off-by: Yang Chen <yangche@fb.com> Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com> Signed-off-by: Yong Hoon Shin <yhshin@meta.com> Signed-off-by: Zijing Liu <liuzijing2014@gmail.com> Signed-off-by: Lu Fang <lufang@fb.com> Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: Lucia Fang <fanglu@fb.com> Signed-off-by: Roger Wang <ywang@roblox.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: Lu Fang <fanglu@fb.com> Co-authored-by: Roger Wang <ywang@roblox.com> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -23,6 +23,7 @@ def cutlass_moe_fp8(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
out_dtype: torch.dtype = torch.half,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
|
||||
@@ -96,8 +97,14 @@ def cutlass_moe_fp8(
|
||||
n = w2_q.size(1)
|
||||
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
|
||||
a2_scale.numel() != 1 if a2_scale is not None else False)
|
||||
if apply_router_weight_on_input:
|
||||
assert topk == 1, \
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
# TODO: this only works for topK=1, will need to update for topK>1
|
||||
a = a * topk_weights.to(out_dtype)
|
||||
|
||||
a_q, a1_scale = ops.scaled_fp8_quant(
|
||||
a, a1_scale, use_per_token_if_dynamic=per_act_token)
|
||||
@@ -139,6 +146,8 @@ def cutlass_moe_fp8(
|
||||
ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale,
|
||||
expert_offsets[:-1], problem_sizes2, ab_strides2,
|
||||
ab_strides2, c_strides2)
|
||||
|
||||
return (c2[c_map].view(m, topk, k) *
|
||||
topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
|
||||
# Gather tokens
|
||||
c2 = c2[c_map].view(m, topk, k)
|
||||
if not apply_router_weight_on_input:
|
||||
c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype)
|
||||
return c2.sum(dim=1)
|
||||
|
||||
Reference in New Issue
Block a user