[Kernel] Add expert_map support to Cutlass FP8 MOE (#16861)
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com> Co-authored-by: varun sundar rabindranath <vsundarr@redhat.com>
This commit is contained in:
committed by
GitHub
parent
c9acbf1141
commit
7b8a2ab76f
@@ -67,7 +67,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
else:
|
||||
return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
|
||||
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
|
||||
and layer.activation == "silu" and layer.expert_map is None):
|
||||
and layer.activation == "silu"):
|
||||
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
|
||||
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
|
||||
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
|
||||
@@ -510,8 +510,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert activation == "silu"
|
||||
assert global_num_experts == layer.w13_weight.shape[0]
|
||||
assert expert_map is None
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
@@ -542,6 +540,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
out_dtype=x.dtype,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user