[Feature][Quantization] MXFP4 support for MOE models (#17888)
Signed-off-by: Felix Marty <felmarty@amd.com> Signed-off-by: Bowen Bao <bowenbao@amd.com> Signed-off-by: Felix Marty <Felix.Marty@amd.com> Co-authored-by: Bowen Bao <bowenbao@amd.com>
This commit is contained in:
@@ -27,6 +27,8 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache, moe_kernel_quantize_input)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
dequant_mxfp4)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import direct_register_custom_op
|
||||
@@ -973,13 +975,16 @@ def get_config_dtype_str(
|
||||
dtype: torch.dtype,
|
||||
use_int4_w4a16: Optional[bool] = False,
|
||||
use_int8_w8a16: Optional[bool] = False,
|
||||
use_fp8_w8a8: Optional[bool] = False) -> Optional[str]:
|
||||
use_fp8_w8a8: Optional[bool] = False,
|
||||
use_mxfp4_w4a4: Optional[bool] = False) -> Optional[str]:
|
||||
if use_fp8_w8a8:
|
||||
return "fp8_w8a8"
|
||||
elif use_int8_w8a16:
|
||||
return "int8_w8a16"
|
||||
elif use_int4_w4a16:
|
||||
return "int4_w4a16"
|
||||
elif use_mxfp4_w4a4:
|
||||
return "mxfp4_w4a4"
|
||||
elif dtype == torch.float:
|
||||
# avoiding cases where kernel fails when float32 MoE
|
||||
# use fp16/bfloat16 configs
|
||||
@@ -998,6 +1003,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
@@ -1011,9 +1017,9 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
|
||||
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
|
||||
activation, apply_router_weight_on_input, use_fp8_w8a8,
|
||||
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
|
||||
per_channel_quant, global_num_experts, expert_map,
|
||||
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
|
||||
block_shape)
|
||||
use_mxfp4_w4a4, per_channel_quant, global_num_experts,
|
||||
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
|
||||
a2_scale, block_shape)
|
||||
|
||||
|
||||
def inplace_fused_experts_fake(
|
||||
@@ -1028,6 +1034,7 @@ def inplace_fused_experts_fake(
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
@@ -1062,6 +1069,7 @@ def outplace_fused_experts(
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
@@ -1075,10 +1083,10 @@ def outplace_fused_experts(
|
||||
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
|
||||
False, activation, apply_router_weight_on_input,
|
||||
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
|
||||
use_int4_w4a16, per_channel_quant,
|
||||
global_num_experts, expert_map, w1_scale,
|
||||
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
|
||||
block_shape)
|
||||
use_int4_w4a16, use_mxfp4_w4a4,
|
||||
per_channel_quant, global_num_experts,
|
||||
expert_map, w1_scale, w2_scale, w1_zp, w2_zp,
|
||||
a1_scale, a2_scale, block_shape)
|
||||
|
||||
|
||||
def outplace_fused_experts_fake(
|
||||
@@ -1092,6 +1100,7 @@ def outplace_fused_experts_fake(
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
@@ -1145,6 +1154,7 @@ def fused_experts(
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
@@ -1203,6 +1213,7 @@ def fused_experts(
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
||||
per_channel_quant=per_channel_quant,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
@@ -1228,6 +1239,7 @@ def fused_experts_impl(
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
@@ -1243,6 +1255,9 @@ def fused_experts_impl(
|
||||
if use_int4_w4a16:
|
||||
assert hidden_states.size(1) // 2 == w1.size(2), (
|
||||
"Hidden size mismatch")
|
||||
elif use_mxfp4_w4a4:
|
||||
# 16bit activation and fp4x2 packed weight
|
||||
assert hidden_states.size(1) // 2 == w1.size(2), "hidden size mismatch"
|
||||
else:
|
||||
assert hidden_states.size(1) == w1.size(2), (
|
||||
f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}")
|
||||
@@ -1268,12 +1283,14 @@ def fused_experts_impl(
|
||||
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
||||
dtype=hidden_states.dtype)
|
||||
|
||||
qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16)
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_mxfp4_w4a4=use_mxfp4_w4a4)
|
||||
|
||||
get_config_func = functools.partial(
|
||||
try_get_optimal_moe_config,
|
||||
@@ -1313,6 +1330,13 @@ def fused_experts_impl(
|
||||
else:
|
||||
out_hidden_states = torch.empty_like(hidden_states)
|
||||
|
||||
if use_mxfp4_w4a4:
|
||||
# Weight has to be dequantized for mxfp4 emulation.
|
||||
w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
|
||||
w1_scale = None
|
||||
w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
|
||||
w2_scale = None
|
||||
|
||||
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
|
||||
begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
|
||||
min((chunk + 1) * CHUNK_SIZE,
|
||||
@@ -1429,6 +1453,7 @@ def fused_moe(
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
@@ -1470,6 +1495,9 @@ def fused_moe(
|
||||
- use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
|
||||
activation to compute the inner products for w1 and w2.
|
||||
Defaults to False.
|
||||
- use_mxfp4_w4a4 (bool): If True, use matmul of OCP MXFP4 weight and
|
||||
OCP MXFP4 activation to compute the inner products for w1 and w2.
|
||||
Defaults to False.
|
||||
- global_num_experts (int): The total number of experts in the global
|
||||
expert space.
|
||||
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
||||
@@ -1513,6 +1541,7 @@ def fused_moe(
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
||||
per_channel_quant=per_channel_quant,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
@@ -1533,6 +1562,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
per_act_token_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
):
|
||||
@@ -1542,6 +1572,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
))
|
||||
@@ -1550,6 +1581,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
self.use_int4_w4a16 = use_int4_w4a16
|
||||
self.use_int8_w8a8 = use_int8_w8a8
|
||||
self.use_int8_w8a16 = use_int8_w8a16
|
||||
self.use_mxfp4_w4a4 = use_mxfp4_w4a4
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
@@ -1627,6 +1659,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
use_mxfp4_w4a4=self.use_mxfp4_w4a4,
|
||||
dtype=hidden_states.dtype)
|
||||
|
||||
config = try_get_optimal_moe_config(
|
||||
@@ -1718,6 +1751,7 @@ def modular_triton_fused_moe(
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
use_mxfp4_w4a4: bool,
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> mk.FusedMoEModularKernel:
|
||||
@@ -1728,6 +1762,7 @@ def modular_triton_fused_moe(
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user