[Kernel] Add expert_map support to Cutlass FP8 MOE (#16861)

Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Co-authored-by: varun sundar rabindranath <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath
2025-04-21 23:44:32 -04:00
committed by GitHub
parent c9acbf1141
commit 7b8a2ab76f
5 changed files with 331 additions and 171 deletions

View File

@@ -67,7 +67,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
else:
return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
-                and layer.activation == "silu" and layer.expert_map is None):
+                and layer.activation == "silu"):
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
@@ -510,8 +510,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
) -> torch.Tensor:
assert activation == "silu"
assert global_num_experts == layer.w13_weight.shape[0]
-        assert expert_map is None
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
@@ -542,6 +540,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
out_dtype=x.dtype,
+            expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)