[MoE Refactor] Mxfp4 oracle rebased (#37128)
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -6,6 +6,7 @@ import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import has_triton_kernels
|
||||
|
||||
if not has_triton_kernels():
|
||||
@@ -14,6 +15,7 @@ if not has_triton_kernels():
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
|
||||
import triton_kernels.swiglu
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
from triton_kernels.numerics import InFlexData
|
||||
@@ -303,6 +305,12 @@ def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init):
|
||||
pc2,
|
||||
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
|
||||
|
||||
if current_platform.is_device_capability_family(100):
|
||||
constraints = {
|
||||
"is_persistent": True,
|
||||
}
|
||||
opt_flags.update_opt_flags_constraints(constraints)
|
||||
|
||||
if a_dtype == "bf16" and w_dtype == "mx4":
|
||||
quant_config = mxfp4_w4a16_moe_quant_config(
|
||||
w1_scale=pc1,
|
||||
|
||||
Reference in New Issue
Block a user