Upstream Llama4 Support to Main (#16113)
Signed-off-by: Aston Zhang <22279212+astonzhang@users.noreply.github.com> Signed-off-by: Chris Thi <chris.c.thi@gmail.com> Signed-off-by: drisspg <drisspguessous@gmail.com> Signed-off-by: Jon Swenson <jmswen@gmail.com> Signed-off-by: Keyun Tong <tongkeyun@gmail.com> Signed-off-by: Lu Fang <fanglu@meta.com> Signed-off-by: Xiaodong Wang <xdwang@meta.com> Signed-off-by: Yang Chen <yangche@fb.com> Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com> Signed-off-by: Yong Hoon Shin <yhshin@meta.com> Signed-off-by: Zijing Liu <liuzijing2014@gmail.com> Signed-off-by: Lu Fang <lufang@fb.com> Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: Lucia Fang <fanglu@fb.com> Signed-off-by: Roger Wang <ywang@roblox.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: Lu Fang <fanglu@fb.com> Co-authored-by: Roger Wang <ywang@roblox.com> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -954,6 +954,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
@@ -967,10 +968,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> None:
|
||||
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
|
||||
activation, use_fp8_w8a8, use_int8_w8a16,
|
||||
use_int4_w4a16, global_num_experts, expert_map,
|
||||
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
|
||||
block_shape)
|
||||
activation, apply_router_weight_on_input, use_fp8_w8a8,
|
||||
use_int8_w8a16, use_int4_w4a16, global_num_experts,
|
||||
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
|
||||
a2_scale, block_shape)
|
||||
|
||||
|
||||
def inplace_fused_experts_fake(
|
||||
@@ -980,6 +981,7 @@ def inplace_fused_experts_fake(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
@@ -1010,6 +1012,7 @@ def outplace_fused_experts(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
@@ -1023,10 +1026,11 @@ def outplace_fused_experts(
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> torch.Tensor:
|
||||
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
|
||||
False, activation, use_fp8_w8a8, use_int8_w8a16,
|
||||
use_int4_w4a16, global_num_experts, expert_map,
|
||||
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
|
||||
a2_scale, block_shape)
|
||||
False, activation, apply_router_weight_on_input,
|
||||
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16,
|
||||
global_num_experts, expert_map, w1_scale,
|
||||
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
|
||||
block_shape)
|
||||
|
||||
|
||||
def outplace_fused_experts_fake(
|
||||
@@ -1084,6 +1088,7 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
@@ -1099,6 +1104,7 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
allow_deep_gemm: bool = False) -> torch.Tensor:
|
||||
if (allow_deep_gemm and use_fp8_w8a8
|
||||
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
|
||||
assert apply_router_weight_on_input is False
|
||||
return deep_gemm_moe_fp8(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
@@ -1122,6 +1128,7 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
@@ -1143,6 +1150,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
@@ -1270,7 +1278,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
False,
|
||||
apply_router_weight_on_input,
|
||||
top_k_num,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
@@ -1307,7 +1315,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
True,
|
||||
not apply_router_weight_on_input,
|
||||
1,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
|
||||
Reference in New Issue
Block a user