Fix Llama4 FlashInfer FP4 MoE issues (#22511)
Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
This commit is contained in:
committed by
GitHub
parent
f7ad6a1eb3
commit
67c153b88a
@@ -170,8 +170,6 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
"w1_scale and w2_scale must not "
|
"w1_scale and w2_scale must not "
|
||||||
"be None for FlashInferExperts")
|
"be None for FlashInferExperts")
|
||||||
|
|
||||||
assert not apply_router_weight_on_input
|
|
||||||
|
|
||||||
quant_scales = [
|
quant_scales = [
|
||||||
a1_gscale,
|
a1_gscale,
|
||||||
w1_scale.view(torch.int32),
|
w1_scale.view(torch.int32),
|
||||||
|
|||||||
@@ -60,7 +60,12 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
||||||
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
|
|
||||||
assert not apply_router_weight_on_input
|
if apply_router_weight_on_input:
|
||||||
|
topk = topk_ids.size(1)
|
||||||
|
# TODO: this only works for topK=1, will need to update for topK>1
|
||||||
|
assert topk == 1, \
|
||||||
|
"apply_router_weight_on_input is only implemented for topk=1"
|
||||||
|
a1.mul_(topk_weights.to(a1.dtype))
|
||||||
|
|
||||||
(a1_gscale, use_dp, local_tokens) = extract_required_args(
|
(a1_gscale, use_dp, local_tokens) = extract_required_args(
|
||||||
extra_prepare_args, ['a1_gscale', 'use_dp', 'local_tokens'])
|
extra_prepare_args, ['a1_gscale', 'use_dp', 'local_tokens'])
|
||||||
|
|||||||
@@ -1299,8 +1299,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
output2_scale_scalar=layer.g2_alphas.data,
|
output2_scale_scalar=layer.g2_alphas.data,
|
||||||
num_experts=global_num_experts,
|
num_experts=global_num_experts,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
n_group=num_expert_group,
|
n_group=num_expert_group
|
||||||
topk_group=topk_group,
|
if num_expert_group is not None else 0,
|
||||||
|
topk_group=topk_group if topk_group is not None else 0,
|
||||||
intermediate_size=layer.intermediate_size_per_partition,
|
intermediate_size=layer.intermediate_size_per_partition,
|
||||||
local_expert_offset=layer.ep_rank * layer.local_num_experts,
|
local_expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||||
local_num_experts=layer.local_num_experts,
|
local_num_experts=layer.local_num_experts,
|
||||||
|
|||||||
Reference in New Issue
Block a user