Bugfix: Cutlass FP8 FusedMoE bad scaling factors (#27255)
Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
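Pass the full set of FP8 scaling factors to the FlashInfer CUTLASS fused-MoE test: the per-GEMM dequantization alphas (g1_alphas, g2_alphas, each a weight scale multiplied by the matching activation scale) and the activation g-scales (a1_gscale, a2_gscale). The module-level skip is relaxed from sm100 to sm90, the per-tensor FP8 test gains its own sm100 guard, and the previously skipped test_flashinfer_cutlass_moe_fp8_no_graph is re-enabled, as the flashinfer fix it was waiting on (flashinfer-ai/flashinfer#1472) is now available.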
@@ -6,7 +6,10 @@ import pytest
 import torch
 
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
-from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEQuantConfig,
+    fp8_w8a8_moe_quant_config,
+)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
@@ -22,10 +25,10 @@ from vllm.platforms import current_platform
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
 
 if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
-    100
+    90
 ):
     pytest.skip(
-        "Requires flashinfer_cutlass_fused_moe and nvfp4 support",
+        "Supported for sm >= 90",
         allow_module_level=True,
     )
 
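Note the split in capability gating: the module-level guard above now admits any sm90+ device, so the CUTLASS FP8 path is exercised on Hopper as well, while the per-tensor FlashInfer test keeps its sm100 requirement via the in-test skip added in the next hunk.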
@@ -131,6 +134,8 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
     topk: int,
     monkeypatch,
 ):
+    if not current_platform.has_device_capability(100):
+        pytest.skip("Test is only supported for sm >= 100")
     current_platform.seed_everything(7)
     monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
     with set_current_vllm_config(vllm_config):
@@ -184,9 +189,6 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
     torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2)
 
 
-@pytest.mark.skip(
-    "Requires flashinfer version that contains https://github.com/flashinfer-ai/flashinfer/pull/1472"
-)
 @pytest.mark.parametrize("m,n,k", MNK_FACTORS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
@@ -216,9 +218,13 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
 
     quant_config = fp8_w8a8_moe_quant_config(
         w1_scale=td.w13_weight_scale,
+        g1_alphas=(td.w13_weight_scale * td.a1_scale).squeeze(),
         w2_scale=td.w2_weight_scale,
+        g2_alphas=(td.w2_weight_scale * td.a2_scale).squeeze(),
         a1_scale=td.a1_scale,
+        a1_gscale=td.a1_scale,
         a2_scale=td.a2_scale,
+        a2_gscale=1.0 / td.a2_scale,
         per_act_token_quant=False,
     )
 
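For context on the new g1_alphas/g2_alphas values: with per-tensor FP8, dequantizing a GEMM collapses into one alpha per matmul, because A @ W ~= (A_q @ W_q) * (a_scale * w_scale). The sketch below checks that identity with stock PyTorch; the quant_fp8 helper is an illustrative assumption, not vLLM's implementation. The reciprocal in a2_gscale = 1.0 / a2_scale is consistent with the quantization direction, where values are divided by (i.e. multiplied by the inverse of) the scale, though that reading is inferred from the patch rather than stated by it.

import torch

def quant_fp8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Per-tensor symmetric quantization to e4m3 (illustrative helper,
    # not vLLM's implementation).
    scale = x.abs().max() / torch.finfo(torch.float8_e4m3fn).max
    return (x / scale).to(torch.float8_e4m3fn), scale

torch.manual_seed(0)
a, w = torch.randn(4, 8), torch.randn(8, 16)
a_q, a_scale = quant_fp8(a)
w_q, w_scale = quant_fp8(w)

# A ~= A_q * a_scale and W ~= W_q * w_scale, so a single fused alpha
# recovers the full-precision product, analogous to
# g1_alphas = w13_weight_scale * a1_scale in the hunk above.
alpha = a_scale * w_scale
y = (a_q.float() @ w_q.float()) * alpha

rel_err = torch.linalg.norm(y - a @ w) / torch.linalg.norm(a @ w)
assert rel_err < 0.15  # loose bound for fp8 round-off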
@@ -238,6 +244,12 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
 
     td.layer.dp_size = 1
 
+    def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig:
+        return quant_config
+
+    td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
+    td.layer.quant_method = td.layer
+
     flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8(
         td.hidden_states,
         td.layer,
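The assignments in this hunk make the layer act as its own quant method so the standalone kernel call can find the config built earlier in the test. A minimal sketch of the duck-typing involved; the lookup chain is an assumption inferred from the patch, not a documented vLLM contract:

import torch

class _Layer(torch.nn.Module):
    """Toy stand-in for td.layer (illustrative only)."""

quant_config = object()  # stand-in for the FusedMoEQuantConfig built above

layer = _Layer()
layer.get_fused_moe_quant_config = lambda n: quant_config
layer.quant_method = layer  # the layer serves as its own quant method

# Presumed resolution inside flashinfer_cutlass_moe_fp8:
cfg = layer.quant_method.get_fused_moe_quant_config(layer)
assert cfg is quant_config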