[Feature][OCP MX] Support mxfp6 and mixed mxfp6-mxfp4 (#21166)

Author: fxmarty-amd
Date: 2025-10-07 15:35:26 +02:00 (committed by GitHub)
Parent: 08d26a1b7e
Commit: 41f1cf38f2
18 changed files with 656 additions and 180 deletions

@@ -42,6 +42,8 @@ from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
+from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
+from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
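
Note on the new ocp_mx_utils import: the hunks below compare ocp_mx_scheme both against plain strings (in _get_config_quant_dtype) and against OCP_MX_Scheme members (in the dequantization block), which works if OCP_MX_Scheme is a str-valued enum. The sketch below only names the five schemes this file switches on; the real definition lives in ocp_mx_utils and may differ.

# Illustrative only -- not the actual ocp_mx_utils definition.
from enum import Enum

class OCP_MX_Scheme(str, Enum):
    w_mxfp4_a_mxfp4 = "w_mxfp4_a_mxfp4"
    w_mxfp4_a_mxfp6_e3m2 = "w_mxfp4_a_mxfp6_e3m2"
    w_mxfp4_a_mxfp6_e2m3 = "w_mxfp4_a_mxfp6_e2m3"
    w_mxfp6_e3m2_a_mxfp6_e3m2 = "w_mxfp6_e3m2_a_mxfp6_e3m2"
    w_mxfp6_e2m3_a_mxfp6_e2m3 = "w_mxfp6_e2m3_a_mxfp6_e2m3"

# A str-valued enum compares equal to the raw string, so both comparison
# styles seen in this diff behave the same way.
assert OCP_MX_Scheme.w_mxfp4_a_mxfp4 == "w_mxfp4_a_mxfp4"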
@@ -1323,7 +1325,7 @@ def inplace_fused_experts(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
-use_mxfp4_w4a4: bool = False,
+ocp_mx_scheme: Optional[str] = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
@@ -1350,7 +1352,7 @@ def inplace_fused_experts(
use_int8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
-use_mxfp4_w4a4,
+ocp_mx_scheme,
per_channel_quant,
global_num_experts,
expert_map,
@@ -1378,7 +1380,7 @@ def inplace_fused_experts_fake(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
-use_mxfp4_w4a4: bool = False,
+ocp_mx_scheme: Optional[str] = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
@@ -1420,7 +1422,7 @@ def outplace_fused_experts(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
-use_mxfp4_w4a4: bool = False,
+ocp_mx_scheme: Optional[str] = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
@@ -1447,7 +1449,7 @@ def outplace_fused_experts(
use_int8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
-use_mxfp4_w4a4,
+ocp_mx_scheme,
per_channel_quant,
global_num_experts,
expert_map,
@@ -1474,7 +1476,7 @@ def outplace_fused_experts_fake(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
-use_mxfp4_w4a4: bool = False,
+ocp_mx_scheme: Optional[str] = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
@@ -1599,7 +1601,7 @@ def fused_experts(
use_int8_w8a8=quant_config.use_int8_w8a8,
use_int8_w8a16=quant_config.use_int8_w8a16,
use_int4_w4a16=quant_config.use_int4_w4a16,
-use_mxfp4_w4a4=quant_config.use_mxfp4_w4a4,
+ocp_mx_scheme=quant_config.ocp_mx_scheme,
per_channel_quant=quant_config.per_act_token_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
@@ -1622,7 +1624,7 @@ GELU_NO_MUL: str = activation_without_mul("gelu")
def _get_config_quant_dtype(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
-use_mxfp4_w4a4: bool,
+ocp_mx_scheme: Optional[str],
) -> Union[None, torch.dtype, str]:
"""
Get the quantization type based on the quantization strategy flags.
@@ -1635,8 +1637,12 @@ def _get_config_quant_dtype(
return torch.float8_e4m3fn
elif use_int8_w8a8:
return torch.int8
-elif use_mxfp4_w4a4:
+elif ocp_mx_scheme == "w_mxfp4_a_mxfp4":
return "mxfp4"
elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e3m2", "w_mxfp6_e3m2_a_mxfp6_e3m2"}:
return "mxfp6_e3m2"
elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}:
return "mxfp6_e2m3"
return None
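
Note: for the two mixed schemes, the config quant dtype above follows the activation format, not the weight format -- "w_mxfp4_a_mxfp6_e3m2" maps to "mxfp6_e3m2" exactly like the pure w6a6 scheme does, since the mxfp4 weights are dequantized before the kernel runs (see the dequantization hunk further down) and only the activations are quantized to the MX format. A minimal restatement of the mapping as a plain dict (illustrative, not part of the change):

# Scheme name -> activation quant dtype, as encoded by the branches above.
OCP_MX_ACT_QUANT_DTYPE = {
    "w_mxfp4_a_mxfp4": "mxfp4",
    "w_mxfp4_a_mxfp6_e3m2": "mxfp6_e3m2",
    "w_mxfp6_e3m2_a_mxfp6_e3m2": "mxfp6_e3m2",
    "w_mxfp4_a_mxfp6_e2m3": "mxfp6_e2m3",
    "w_mxfp6_e2m3_a_mxfp6_e2m3": "mxfp6_e2m3",
}

# Mixed w4a6 schemes quantize activations to mxfp6, same as the pure w6a6 schemes.
assert (OCP_MX_ACT_QUANT_DTYPE["w_mxfp4_a_mxfp6_e2m3"]
        == OCP_MX_ACT_QUANT_DTYPE["w_mxfp6_e2m3_a_mxfp6_e2m3"])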
@@ -1653,7 +1659,7 @@ def fused_experts_impl(
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
-use_mxfp4_w4a4: bool = False,
+ocp_mx_scheme: Optional[str] = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
@@ -1670,9 +1676,23 @@ def fused_experts_impl(
# Check constraints.
if use_int4_w4a16:
assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch"
-elif use_mxfp4_w4a4:
-    # 16bit activation and fp4x2 packed weight
-    assert hidden_states.size(1) // 2 == w1.size(2), "hidden size mismatch"
+elif ocp_mx_scheme is not None:
+    if ocp_mx_scheme in {
+        "w_mxfp4_a_mxfp4",
+        "w_mxfp4_a_mxfp6_e3m2",
+        "w_mxfp4_a_mxfp6_e2m3",
+    }:
+        # 16bit activation and fp4x2 packed weight
+        assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
+    elif ocp_mx_scheme in {
+        "w_mxfp6_e3m2_a_mxfp6_e3m2",
+        "w_mxfp6_e2m3_a_mxfp6_e2m3",
+    }:
+        assert hidden_states.size(1) == (w1.size(2) * 4) // 3, (
+            "hidden size mismatch"
+        )
+    else:
+        raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
else:
assert hidden_states.size(1) == w1.size(2), (
f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"
@@ -1699,7 +1719,7 @@ def fused_experts_impl(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
-use_mxfp4_w4a4=use_mxfp4_w4a4,
+ocp_mx_scheme=ocp_mx_scheme,
dtype=hidden_states.dtype,
)
@@ -1708,7 +1728,7 @@ def fused_experts_impl(
quant_dtype = _get_config_quant_dtype(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
-use_mxfp4_w4a4=use_mxfp4_w4a4,
+ocp_mx_scheme=ocp_mx_scheme,
)
get_config_func = functools.partial(
@@ -1748,12 +1768,40 @@ def fused_experts_impl(
out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)
-if use_mxfp4_w4a4:
-    # Weight has to be dequantized for mxfp4 emulation.
-    w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
-    w1_scale = None
-    w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
-    w2_scale = None
+if ocp_mx_scheme is not None:
+    # TODO: On platforms for which `current_platform.supports_mx()` is True
+    # and for which we have a native OCP mx fused MOE kernel,
+    # this dequantization step should not be done.
+    if ocp_mx_scheme in {
+        OCP_MX_Scheme.w_mxfp4_a_mxfp4,
+        OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2,
+        OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3,
+    }:
+        # Weight has to be dequantized for mxfp4 emulation.
+        w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
+        w1_scale = None
+        w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
+        w2_scale = None
+    elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2:
+        w1 = dequant_mxfp6(
+            w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
+        )
+        w1_scale = None
+        w2 = dequant_mxfp6(
+            w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
+        )
+        w2_scale = None
+    elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3:
+        w1 = dequant_mxfp6(
+            w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
+        )
+        w1_scale = None
+        w2 = dequant_mxfp6(
+            w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
+        )
+        w2_scale = None
+    else:
+        raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
begin_chunk_idx, end_chunk_idx = (
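
Note: all OCP MX schemes are currently emulated here -- the packed weights are dequantized to the activation dtype and their scales dropped before the regular fused kernel runs, as the TODO in the hunk above points out. In the OCP MX formats, every group of 32 elements shares one power-of-two (E8M0) scale, which is what the w1_scale / w2_scale tensors carry here. For reference, a single fp6 element decodes as sketched below, following the OCP Microscaling spec (E2M3: 2 exponent bits, 3 mantissa bits, bias 1; E3M2: 3 exponent bits, 2 mantissa bits, bias 3); this is only an illustration, not the vectorized implementation in mxfp6_utils.

# Illustrative fp6 decode, not vLLM's implementation.
def decode_fp6(code: int, fmt: str) -> float:
    """Decode a 6-bit OCP fp6 code ("e2m3" or "e3m2") to a Python float.

    Subnormals use exponent field 0; neither format has inf/nan encodings.
    """
    assert 0 <= code < 64
    if fmt == "e2m3":
        exp_bits, man_bits, bias = 2, 3, 1
    elif fmt == "e3m2":
        exp_bits, man_bits, bias = 3, 2, 3
    else:
        raise ValueError(fmt)
    sign = -1.0 if (code >> (exp_bits + man_bits)) & 1 else 1.0
    exp = (code >> man_bits) & ((1 << exp_bits) - 1)
    man = code & ((1 << man_bits) - 1)
    if exp == 0:  # subnormal: no implicit leading 1
        return sign * (man / (1 << man_bits)) * 2.0 ** (1 - bias)
    return sign * (1.0 + man / (1 << man_bits)) * 2.0 ** (exp - bias)

# Largest normal magnitudes: 7.5 for e2m3 and 28.0 for e3m2.
assert decode_fp6(0b011111, "e2m3") == 7.5
assert decode_fp6(0b011111, "e3m2") == 28.0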