[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -8,7 +8,8 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
per_token_quant_int8)
|
||||
from vllm.platforms import current_platform
|
||||
@@ -42,7 +43,8 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
|
||||
return C.reshape(origin_C_shape).to(output_dtype)
|
||||
|
||||
|
||||
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
|
||||
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight,
|
||||
topk_ids):
|
||||
"""This function performs fused moe with per-column int8 quantization
|
||||
using native torch."""
|
||||
|
||||
@@ -57,8 +59,6 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
|
||||
# Calculate routing
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
# Process each expert
|
||||
@@ -127,20 +127,27 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
|
||||
w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
|
||||
w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weights, topk_ids = torch.topk(score, topk)
|
||||
|
||||
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
|
||||
out = fused_moe(
|
||||
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk,
|
||||
topk_weights, topk_ids)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
torch.int8,
|
||||
per_act_token_quant=True,
|
||||
block_shape=None,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
)
|
||||
|
||||
out = fused_experts(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
use_int8_w8a8=True, # Using int8-w8a8
|
||||
per_channel_quant=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
block_shape=None, # Not using block quantization
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
# Check results
|
||||
|
||||
Reference in New Issue
Block a user