[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -23,6 +23,7 @@ from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.testing import assert_close
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedPrepareAndFinalize)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
@@ -293,6 +294,13 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
|
||||
pc2,
|
||||
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
w1_bias=w1_bias_tri,
|
||||
w2_bias=w2_bias_tri,
|
||||
w1_precision=pc1,
|
||||
w2_precision=pc2,
|
||||
)
|
||||
|
||||
out_triton_monolithic = triton_kernel_moe_forward(
|
||||
hidden_states=x_tri,
|
||||
w1=w1_tri,
|
||||
@@ -300,10 +308,7 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
|
||||
gating_output=exp_data_tri,
|
||||
topk=topk,
|
||||
renormalize=True,
|
||||
w1_bias=w1_bias_tri,
|
||||
w2_bias=w2_bias_tri,
|
||||
w1_precision=pc1,
|
||||
w2_precision=pc2,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
out_triton_monolithic = out_triton_monolithic[..., :K]
|
||||
|
||||
@@ -336,6 +341,13 @@ def batched_moe(
|
||||
) -> torch.Tensor:
|
||||
max_num_tokens = round_up(a.shape[0], 64)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
w1_precision=w1_precision,
|
||||
w2_precision=w2_precision,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
BatchedPrepareAndFinalize(
|
||||
max_num_tokens,
|
||||
@@ -344,19 +356,12 @@ def batched_moe(
|
||||
rank=0,
|
||||
),
|
||||
BatchedOAITritonExperts(
|
||||
None,
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
w1_precision=w1_precision,
|
||||
w2_precision=w2_precision,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
)
|
||||
|
||||
extra_expert_args = {
|
||||
"w1_bias": w1_bias,
|
||||
"w2_bias": w2_bias,
|
||||
}
|
||||
|
||||
topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize)
|
||||
|
||||
return fused_experts(
|
||||
@@ -365,7 +370,6 @@ def batched_moe(
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
extra_expert_args=extra_expert_args,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user