[MoE Refactor] Oracle Select FP8+NVFP4 Kernels In Priority (#32414)
@@ -4,7 +4,12 @@
 import pytest
 import torch
 
-from tests.kernels.moe.utils import make_test_quant_config, make_test_weights
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from tests.kernels.moe.utils import (
+    make_dummy_moe_config,
+    make_test_quant_config,
+    make_test_weights,
+)
 from tests.kernels.quant_utils import (
     native_per_token_group_quant_fp8,
     native_w8a8_block_matmul,
@@ -15,13 +20,21 @@ from vllm.model_executor.layers.fused_moe import (
     fused_experts,
     fused_topk,
 )
+from vllm.model_executor.layers.fused_moe.config import (
+    fp8_w8a8_moe_quant_config,
+)
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
     _valid_deep_gemm_shape,
-    deep_gemm_moe_fp8,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     modular_triton_fused_moe,
 )
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    MoEPrepareAndFinalizeNoEP,
+)
+from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
+    TritonOrDeepGemmExperts,
+)
 from vllm.platforms import current_platform
 from vllm.utils.deep_gemm import (
     get_mk_alignment_for_contiguous_layout,
@@ -161,7 +174,7 @@ def test_w8a8_block_fp8_fused_moe(
         block_shape=block_size,
     )
 
-    m_fused_moe = modular_triton_fused_moe(quant_config)
+    m_fused_moe = modular_triton_fused_moe(make_dummy_moe_config(), quant_config)
 
     topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
 
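The hunk above shows the API change on the Triton path: modular_triton_fused_moe now receives a MoE config ahead of the quantization config. As a minimal, hedged sketch (not part of the diff), the returned modular kernel is assumed to be invoked with the same keyword arguments as the deep_gemm_experts kernel constructed in the next hunk; the tensors a, w1, w2, topk_weights and topk_ids come from the surrounding test:

# Sketch only: assumes modular_triton_fused_moe returns a FusedMoEModularKernel
# that is callable like the deep_gemm_experts kernel in the following hunk.
m_fused_moe = modular_triton_fused_moe(make_dummy_moe_config(), quant_config)
out = m_fused_moe(
    hidden_states=a,
    w1=w1,
    w2=w2,
    topk_weights=topk_weights,
    topk_ids=topk_ids,
)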
@@ -236,6 +249,29 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
 
     topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
 
+    quant_config = fp8_w8a8_moe_quant_config(
+        w1_scale=w1_s,
+        w2_scale=w2_s,
+        block_shape=block_size,
+    )
+
+    deep_gemm_experts = mk.FusedMoEModularKernel(
+        prepare_finalize=MoEPrepareAndFinalizeNoEP(),
+        fused_experts=TritonOrDeepGemmExperts(
+            moe_config=make_dummy_moe_config(),
+            quant_config=quant_config,
+        ),
+    )
+
+    def deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids):
+        return deep_gemm_experts(
+            hidden_states=a,
+            w1=w1,
+            w2=w2,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+        )
+
     # Set the context to avoid lots of warning spam.
     with set_current_vllm_config(vllm_config):
         ref_out = torch_w8a8_block_fp8_moe(
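The locally defined deep_gemm_moe_fp8 wrapper keeps the positional signature of the library function it replaces, so call sites further down in the test do not need to change; w1_s and w2_s are accepted only for compatibility, since the scales are already carried by quant_config. A hedged sketch of such a call site (the comparison against ref_out is assumed from the surrounding test, not shown in the diff):

# Hypothetical call site below the hunk shown above.
out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
# out would then be compared against ref_out with test-specific tolerances.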