[Feature][OCP MX] Support mxfp6 and mixed mxfp6-mxfp4 (#21166)
@@ -42,6 +42,8 @@ from vllm.model_executor.layers.fused_moe.utils import (
     moe_kernel_quantize_input,
 )
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
+from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
+from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
@@ -1323,7 +1325,7 @@ def inplace_fused_experts(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
-    use_mxfp4_w4a4: bool = False,
+    ocp_mx_scheme: Optional[str] = None,
     per_channel_quant: bool = False,
     global_num_experts: int = -1,
     expert_map: Optional[torch.Tensor] = None,
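Note: throughout this diff, the boolean `use_mxfp4_w4a4` flag is superseded by an `ocp_mx_scheme` string that names the weight (`w_`) and activation (`a_`) formats. For quick reference, the scheme values exercised in this change are listed below; the constant name is ours, for illustration only.

# Scheme strings appearing in this diff ("w_<weight fmt>_a_<activation fmt>").
# The name OCP_MX_SCHEMES is illustrative, not part of the PR.
OCP_MX_SCHEMES = [
    "w_mxfp4_a_mxfp4",            # mxfp4 weights, mxfp4 activations
    "w_mxfp4_a_mxfp6_e3m2",       # mixed: mxfp4 weights, mxfp6 (E3M2) activations
    "w_mxfp4_a_mxfp6_e2m3",       # mixed: mxfp4 weights, mxfp6 (E2M3) activations
    "w_mxfp6_e3m2_a_mxfp6_e3m2",  # mxfp6 (E3M2) weights and activations
    "w_mxfp6_e2m3_a_mxfp6_e2m3",  # mxfp6 (E2M3) weights and activations
]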
@@ -1350,7 +1352,7 @@ def inplace_fused_experts(
         use_int8_w8a8,
         use_int8_w8a16,
         use_int4_w4a16,
-        use_mxfp4_w4a4,
+        ocp_mx_scheme,
         per_channel_quant,
         global_num_experts,
         expert_map,
@@ -1378,7 +1380,7 @@ def inplace_fused_experts_fake(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
-    use_mxfp4_w4a4: bool = False,
+    ocp_mx_scheme: Optional[str] = None,
     per_channel_quant: bool = False,
     global_num_experts: int = -1,
     expert_map: Optional[torch.Tensor] = None,
@@ -1420,7 +1422,7 @@ def outplace_fused_experts(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
-    use_mxfp4_w4a4: bool = False,
+    ocp_mx_scheme: Optional[str] = None,
     per_channel_quant: bool = False,
     global_num_experts: int = -1,
     expert_map: Optional[torch.Tensor] = None,
@@ -1447,7 +1449,7 @@ def outplace_fused_experts(
         use_int8_w8a8,
         use_int8_w8a16,
         use_int4_w4a16,
-        use_mxfp4_w4a4,
+        ocp_mx_scheme,
         per_channel_quant,
         global_num_experts,
         expert_map,
@@ -1474,7 +1476,7 @@ def outplace_fused_experts_fake(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
-    use_mxfp4_w4a4: bool = False,
+    ocp_mx_scheme: Optional[str] = None,
     per_channel_quant: bool = False,
     global_num_experts: int = -1,
     expert_map: Optional[torch.Tensor] = None,
@@ -1599,7 +1601,7 @@ def fused_experts(
         use_int8_w8a8=quant_config.use_int8_w8a8,
         use_int8_w8a16=quant_config.use_int8_w8a16,
         use_int4_w4a16=quant_config.use_int4_w4a16,
-        use_mxfp4_w4a4=quant_config.use_mxfp4_w4a4,
+        ocp_mx_scheme=quant_config.ocp_mx_scheme,
         per_channel_quant=quant_config.per_act_token_quant,
         global_num_experts=global_num_experts,
         expert_map=expert_map,
@@ -1622,7 +1624,7 @@ GELU_NO_MUL: str = activation_without_mul("gelu")
 def _get_config_quant_dtype(
     use_fp8_w8a8: bool,
     use_int8_w8a8: bool,
-    use_mxfp4_w4a4: bool,
+    ocp_mx_scheme: Optional[str],
 ) -> Union[None, torch.dtype, str]:
     """
     Get the quantization type based on the quantization strategy flags.
@@ -1635,8 +1637,12 @@ def _get_config_quant_dtype(
         return torch.float8_e4m3fn
     elif use_int8_w8a8:
         return torch.int8
-    elif use_mxfp4_w4a4:
+    elif ocp_mx_scheme == "w_mxfp4_a_mxfp4":
         return "mxfp4"
+    elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e3m2", "w_mxfp6_e3m2_a_mxfp6_e3m2"}:
+        return "mxfp6_e3m2"
+    elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}:
+        return "mxfp6_e2m3"
     return None


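For the new branch above, the returned quant dtype follows the activation side of the scheme name, since this value drives activation quantization. A sketch of the expected results, derived only from the code above (note the branch compares plain strings, while the dequant block later in the diff uses the equivalent `OCP_MX_Scheme` members):

# Illustrative expectations for the branch above.
assert _get_config_quant_dtype(False, False, "w_mxfp4_a_mxfp4") == "mxfp4"
assert _get_config_quant_dtype(False, False, "w_mxfp4_a_mxfp6_e3m2") == "mxfp6_e3m2"
assert _get_config_quant_dtype(False, False, "w_mxfp6_e2m3_a_mxfp6_e2m3") == "mxfp6_e2m3"
assert _get_config_quant_dtype(False, False, None) is None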
@@ -1653,7 +1659,7 @@ def fused_experts_impl(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
-    use_mxfp4_w4a4: bool = False,
+    ocp_mx_scheme: Optional[str] = None,
     per_channel_quant: bool = False,
     global_num_experts: int = -1,
     expert_map: Optional[torch.Tensor] = None,
@@ -1670,9 +1676,23 @@ def fused_experts_impl(
     # Check constraints.
     if use_int4_w4a16:
         assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch"
-    elif use_mxfp4_w4a4:
-        # 16bit activation and fp4x2 packed weight
-        assert hidden_states.size(1) // 2 == w1.size(2), "hidden size mismatch"
+    elif ocp_mx_scheme is not None:
+        if ocp_mx_scheme in {
+            "w_mxfp4_a_mxfp4",
+            "w_mxfp4_a_mxfp6_e3m2",
+            "w_mxfp4_a_mxfp6_e2m3",
+        }:
+            # 16bit activation and fp4x2 packed weight
+            assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
+        elif ocp_mx_scheme in {
+            "w_mxfp6_e3m2_a_mxfp6_e3m2",
+            "w_mxfp6_e2m3_a_mxfp6_e2m3",
+        }:
+            assert hidden_states.size(1) == (w1.size(2) * 4) // 3, (
+                "hidden size mismatch"
+            )
+        else:
+            raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
     else:
         assert hidden_states.size(1) == w1.size(2), (
             f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"
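The shape checks above encode each format's byte packing: fp4 stores two values per byte, so the packed weight dimension is half the hidden size, while fp6 stores four values in three bytes, making it three quarters. A minimal sketch of that arithmetic, with a helper name of our own:

# Packed last-dimension size per format (helper is illustrative, not from the PR).
def packed_dim(hidden_size: int, fmt: str) -> int:
    if fmt == "mxfp4":
        return hidden_size // 2        # 2 x 4-bit values per byte
    if fmt.startswith("mxfp6"):
        return (hidden_size * 3) // 4  # 4 x 6-bit values per 3 bytes
    return hidden_size                 # unpacked 8/16-bit formats

# e.g. hidden_size = 4096 -> mxfp4 weight dim 2048, mxfp6 weight dim 3072;
# inverting gives exactly the assertions above:
assert 2048 * 2 == 4096 and (3072 * 4) // 3 == 4096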
@@ -1699,7 +1719,7 @@ def fused_experts_impl(
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         use_int4_w4a16=use_int4_w4a16,
-        use_mxfp4_w4a4=use_mxfp4_w4a4,
+        ocp_mx_scheme=ocp_mx_scheme,
         dtype=hidden_states.dtype,
     )

@@ -1708,7 +1728,7 @@ def fused_experts_impl(
     quant_dtype = _get_config_quant_dtype(
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
-        use_mxfp4_w4a4=use_mxfp4_w4a4,
+        ocp_mx_scheme=ocp_mx_scheme,
     )

     get_config_func = functools.partial(
@@ -1748,12 +1768,40 @@ def fused_experts_impl(

     out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)

-    if use_mxfp4_w4a4:
-        # Weight has to be dequantized for mxfp4 emulation.
-        w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
-        w1_scale = None
-        w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
-        w2_scale = None
+    if ocp_mx_scheme is not None:
+        # TODO: On platforms for which `current_platform.supports_mx()` is True
+        # and for which we have a native OCP mx fused MOE kernel,
+        # this dequantization step should not be done.
+        if ocp_mx_scheme in {
+            OCP_MX_Scheme.w_mxfp4_a_mxfp4,
+            OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2,
+            OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3,
+        }:
+            # Weight has to be dequantized for mxfp4 emulation.
+            w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
+            w1_scale = None
+            w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
+            w2_scale = None
+        elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2:
+            w1 = dequant_mxfp6(
+                w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
+            )
+            w1_scale = None
+            w2 = dequant_mxfp6(
+                w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
+            )
+            w2_scale = None
+        elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3:
+            w1 = dequant_mxfp6(
+                w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
+            )
+            w1_scale = None
+            w2 = dequant_mxfp6(
+                w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
+            )
+            w2_scale = None
+        else:
+            raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")

     for chunk in range((num_tokens // CHUNK_SIZE) + 1):
         begin_chunk_idx, end_chunk_idx = (
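For context on the emulation path: OCP MX formats share one E8M0 scale per 32-element block, and on platforms without native MX support the weights are dequantized back to the activation dtype before the Triton kernel runs (hence the TODO above). Below is a self-contained sketch of block-wise MX dequantization, assuming already-decoded element values and the spec's exponent bias of 127; it is not vLLM's `dequant_mxfp4`/`dequant_mxfp6` implementation.

import torch

def dequant_mx_blocks(
    decoded: torch.Tensor,      # (..., N) values decoded from fp4/fp6 bit patterns
    scale_e8m0: torch.Tensor,   # (..., N // 32) raw E8M0 exponents as uint8
    float_dtype: torch.dtype,
    block_size: int = 32,
) -> torch.Tensor:
    # Each block of 32 elements is scaled by 2**(E8M0 - 127).
    scale = torch.exp2(scale_e8m0.to(torch.float32) - 127.0)
    out = decoded.to(torch.float32).unflatten(-1, (-1, block_size)) * scale.unsqueeze(-1)
    return out.flatten(-2).to(float_dtype)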