[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)
Signed-off-by: Bill Nell <bnell@redhat.com>
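The change replaces the pile of per-kernel boolean knobs (use_fp8_w8a8, use_int8_w8a16, ...) with a single FusedMoEQuantConfig object, and moves responsibility for building that config into each FusedMoEMethodBase subclass, which knows its own quantization scheme. A minimal sketch of the delegation pattern, assuming hypothetical class and hook names (QuantConfigSketch, MoEMethodSketch, and get_quant_config are illustrative stand-ins, not the exact vLLM interface):

    # Sketch only: illustrates the delegation pattern named in the title.
    from dataclasses import dataclass
    from typing import Optional

    import torch

    @dataclass
    class QuantConfigSketch:
        # Stand-in for FusedMoEQuantConfig; the fields are assumptions.
        quant_dtype: Optional[torch.dtype] = None
        w1_scale: Optional[torch.Tensor] = None
        w2_scale: Optional[torch.Tensor] = None
        block_shape: Optional[list] = None

    class MoEMethodSketch:
        # Stand-in for FusedMoEMethodBase; the hook name is an assumption.
        def get_quant_config(self) -> QuantConfigSketch:
            # Unquantized default; quantized subclasses override this.
            return QuantConfigSketch()

    class Int8W8A16MethodSketch(MoEMethodSketch):
        def __init__(self, w1_scale: torch.Tensor, w2_scale: torch.Tensor,
                     group_size: int):
            self.w1_scale = w1_scale
            self.w2_scale = w2_scale
            self.group_size = group_size

        def get_quant_config(self) -> QuantConfigSketch:
            # The method subclass, not the kernel caller, decides how its
            # quantization parameters are packaged into the config.
            return QuantConfigSketch(quant_dtype=torch.int8,
                                     w1_scale=self.w1_scale,
                                     w2_scale=self.w2_scale,
                                     block_shape=[0, self.group_size])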
@@ -15,11 +15,14 @@ from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 import vllm.model_executor.layers.fused_moe  # noqa
 from tests.kernels.moe.utils import fused_moe
 from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.distributed.parallel_state import init_distributed_environment
 from vllm.forward_context import set_forward_context
 from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.fused_moe.config import (
+    FUSED_MOE_UNQUANTIZED_CONFIG, int4_w4a16_moe_quant_config,
+    int8_w8a16_moe_quant_config)
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, modular_triton_fused_moe)
 from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
@@ -187,14 +190,9 @@ def test_fused_moe(
     #
     # Setup test functions
     #
+    quant_config = FUSED_MOE_UNQUANTIZED_CONFIG

-    m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False,
-                                              use_int8_w8a8=False,
-                                              use_int8_w8a16=False,
-                                              use_int4_w4a16=False,
-                                              use_mxfp4_w4a4=False,
-                                              per_act_token_quant=False,
-                                              block_shape=None)
+    m_fused_moe_fn = modular_triton_fused_moe(quant_config)

     def m_fused_moe(
         a: torch.Tensor,
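For the unquantized path, the test now reaches for the shared FUSED_MOE_UNQUANTIZED_CONFIG constant instead of spelling out seven disabled flags. Side by side, condensed from the hunk above (no new API assumed):

    # Before: every quantization knob passed explicitly, all disabled.
    m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False,
                                              use_int8_w8a8=False,
                                              use_int8_w8a16=False,
                                              use_int4_w4a16=False,
                                              use_mxfp4_w4a4=False,
                                              per_act_token_quant=False,
                                              block_shape=None)

    # After: a single config object carries the same information.
    m_fused_moe_fn = modular_triton_fused_moe(FUSED_MOE_UNQUANTIZED_CONFIG)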
@@ -340,6 +338,18 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
     else:
         e_map = None

+    if weight_bits == 4:
+        quant_config_builder = int4_w4a16_moe_quant_config
+    else:
+        assert weight_bits == 8
+        quant_config_builder = int8_w8a16_moe_quant_config
+
+    quant_config = quant_config_builder(w1_scale=w1_scales,
+                                        w2_scale=w2_scales,
+                                        w1_zp=w1_qzeros if has_zp else None,
+                                        w2_zp=w2_qzeros if has_zp else None,
+                                        block_shape=[0, group_size])
+
     with set_current_vllm_config(vllm_config):
         triton_output = fused_moe(a,
                                   w1_qweight,
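The two builders share one keyword signature, so the test picks a builder by weight width and passes identical arguments. A usage sketch with made-up shapes (the tensor shapes and group layout here are assumptions for illustration):

    import torch
    from vllm.model_executor.layers.fused_moe.config import (
        int4_w4a16_moe_quant_config, int8_w8a16_moe_quant_config)

    # Hypothetical sizes: e experts, weights group-quantized along k.
    e, n, k, group_size = 8, 256, 128, 64
    w1_scales = torch.rand(e, 2 * n, k // group_size)
    w2_scales = torch.rand(e, k, n // group_size)

    builder = int4_w4a16_moe_quant_config  # or int8_w8a16_moe_quant_config
    quant_config = builder(w1_scale=w1_scales,
                           w2_scale=w2_scales,
                           w1_zp=None,  # symmetric case: no zero points
                           w2_zp=None,
                           block_shape=[0, group_size])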
@@ -347,15 +357,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
                                   score,
                                   topk,
                                   renormalize=False,
-                                  use_int4_w4a16=weight_bits == 4,
-                                  use_int8_w8a16=weight_bits == 8,
                                   global_num_experts=e,
                                   expert_map=e_map,
-                                  w1_scale=w1_scales,
-                                  w2_scale=w2_scales,
-                                  w1_zp=w1_qzeros if has_zp else None,
-                                  w2_zp=w2_qzeros if has_zp else None,
-                                  block_shape=[0, group_size])
+                                  quant_config=quant_config)
         torch_output = torch_moe(a,
                                  w1_ref,
                                  w2_ref,
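Taken together, the int4/int8 test call condenses to the following, with every quantization parameter travelling inside the config (condensed from the hunks above; all argument names come from the diff):

    quant_config = quant_config_builder(w1_scale=w1_scales,
                                        w2_scale=w2_scales,
                                        w1_zp=w1_qzeros if has_zp else None,
                                        w2_zp=w2_qzeros if has_zp else None,
                                        block_shape=[0, group_size])
    triton_output = fused_moe(a, w1_qweight, w2_qweight, score, topk,
                              renormalize=False,
                              global_num_experts=e,
                              expert_map=e_map,
                              quant_config=quant_config)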