Bugfix: Cutlass FP8 FusedMoE bad scaling factors (#27255)
Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -463,6 +463,10 @@ def fp8_w8a8_moe_quant_config(
|
||||
per_act_token_quant: bool = False,
|
||||
per_out_ch_quant: bool = False,
|
||||
block_shape: list[int] | None = None,
|
||||
a1_gscale: torch.Tensor | None = None,
|
||||
a2_gscale: torch.Tensor | None = None,
|
||||
g1_alphas: torch.Tensor | None = None,
|
||||
g2_alphas: torch.Tensor | None = None,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for fp8 activations and fp8 weights.
|
||||
@@ -470,9 +474,13 @@ def fp8_w8a8_moe_quant_config(
|
||||
return FusedMoEQuantConfig.make(
|
||||
torch.float8_e4m3fn,
|
||||
w1_scale=w1_scale,
|
||||
g1_alphas=g1_alphas,
|
||||
w2_scale=w2_scale,
|
||||
g2_alphas=g2_alphas,
|
||||
a1_scale=a1_scale,
|
||||
a1_gscale=a1_gscale,
|
||||
a2_scale=a2_scale,
|
||||
a2_gscale=a2_gscale,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_out_ch_quant,
|
||||
block_shape=block_shape,
|
||||
|
||||
@@ -170,7 +170,7 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
|
||||
self._apply_router_weight_on_input(
|
||||
a1, topk_weights, topk_ids, apply_router_weight_on_input
|
||||
)
|
||||
if not self.use_dp:
|
||||
if not self.use_dp and quant_config.quant_dtype == "nvfp4":
|
||||
return a1, None, None, topk_ids, topk_weights
|
||||
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
@@ -181,11 +181,13 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
|
||||
quant_config.block_shape,
|
||||
is_fp4_scale_swizzled=not self.use_dp,
|
||||
)
|
||||
topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
|
||||
[topk_weights, topk_ids, a1q, a1q_scale],
|
||||
dim=0,
|
||||
sizes=get_local_sizes(),
|
||||
)
|
||||
|
||||
if self.use_dp:
|
||||
topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
|
||||
[topk_weights, topk_ids, a1q, a1q_scale],
|
||||
dim=0,
|
||||
sizes=get_local_sizes(),
|
||||
)
|
||||
if quant_config.quant_dtype == "nvfp4":
|
||||
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
||||
|
||||
|
||||
@@ -567,9 +567,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
g1_alphas=(layer.w13_weight_scale * layer.w13_input_scale).squeeze(),
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
g2_alphas=(layer.w2_weight_scale * layer.w2_input_scale).squeeze(),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a1_gscale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
a2_gscale=1.0 / layer.w2_input_scale,
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
|
||||
@@ -1138,8 +1142,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> None:
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
|
||||
detect_nvfp4_moe_support,
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
|
||||
detect_nvfp4_moe_support, # noqa: E501
|
||||
)
|
||||
|
||||
super().__init__(moe)
|
||||
|
||||
Reference in New Issue
Block a user