[Quant] Support MXFP4 W4A16 for compressed-tensors MoE models (#32285)

Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>

tests/evals/gsm8k/configs/Qwen3-30B-A3B-MXFP4A16.yaml (new file)
@@ -0,0 +1,5 @@
+model_name: nm-testing/Qwen3-30B-A3B-MXFP4A16
+accuracy_threshold: 0.88
+num_questions: 1319
+num_fewshot: 5
+server_args: "--enforce-eager --max-model-len 4096"
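
The five keys above are what the GSM8K eval harness needs: the HF model id, an accuracy floor, the question and few-shot counts, and extra server flags. As a rough illustration (not code from this PR; load_eval_config is hypothetical), such a config could be consumed like this:

    import yaml

    def load_eval_config(path: str) -> dict:
        """Hypothetical loader for an eval config like the one above."""
        with open(path) as f:
            cfg = yaml.safe_load(f)
        # The harness serves cfg["model_name"] with cfg["server_args"], asks
        # cfg["num_questions"] GSM8K questions with cfg["num_fewshot"] shots,
        # then requires accuracy >= cfg["accuracy_threshold"].
        missing = {"model_name", "accuracy_threshold", "num_questions"} - cfg.keys()
        assert not missing, f"{path} is missing keys: {missing}"
        return cfg
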
@@ -4,3 +4,4 @@ Llama-3-8B-Instruct-nonuniform-CT.yaml
 Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
 Qwen1.5-MoE-W4A16-CT.yaml
 DeepSeek-V2-Lite-Instruct-FP8.yaml
+Qwen3-30B-A3B-MXFP4A16.yaml

vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
@@ -10,6 +10,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEQuantConfig,
+    mxfp4_w4a16_moe_quant_config,
     nvfp4_moe_quant_config,
     nvfp4_w4a16_moe_quant_config,
 )
@@ -193,6 +194,16 @@ def convert_to_nvfp4_moe_kernel_format(
 )


+def make_mxfp4_moe_quant_config(
+    w13_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+) -> FusedMoEQuantConfig:
+    return mxfp4_w4a16_moe_quant_config(
+        w1_scale=w13_scale,
+        w2_scale=w2_scale,
+    )
+
+
 def make_nvfp4_moe_quant_config(
     backend: NvFp4MoeBackend,
     w13_scale: torch.Tensor,
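
For illustration, invoking the new helper with placeholder tensors; the shapes follow the create_weights() definitions later in this diff (one uint8 scale per 32-element input group), and every number here is made up:

    import torch

    num_experts, hidden_size, intermediate = 8, 1024, 512
    w13_scale = torch.zeros(
        num_experts, 2 * intermediate, hidden_size // 32, dtype=torch.uint8
    )
    w2_scale = torch.zeros(
        num_experts, hidden_size, intermediate // 32, dtype=torch.uint8
    )
    quant_config = make_mxfp4_moe_quant_config(w13_scale=w13_scale, w2_scale=w2_scale)
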
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -52,6 +52,7 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
     NvFp4MoeBackend,
     convert_to_nvfp4_moe_kernel_format,
     is_global_sf_supported_for_nvfp4_backend,
+    make_mxfp4_moe_quant_config,
     make_nvfp4_moe_kernel,
     make_nvfp4_moe_quant_config,
     select_nvfp4_moe_backend,
@@ -79,6 +80,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
 )
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
     is_fp4_marlin_supported,
+    prepare_moe_fp4_layer_for_marlin,
 )
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     convert_bf16_scales_to_fp8,
@@ -145,6 +147,9 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
         input_quant = scheme_dict.get("input_activations")
         format = scheme_dict.get("format")

+        if quant_config._is_mxfp4(weight_quant):
+            return CompressedTensorsW4A4Mxfp4MoEMethod(layer.moe_config)
+
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
             # group_size=None means channelwise
             group_size = weight_quant.group_size or -1
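
The _is_mxfp4 predicate itself is outside this excerpt. A plausible sketch, assuming compressed-tensors' QuantizationArgs fields (num_bits, type, group_size); the real check in vLLM may differ:

    def _is_mxfp4(self, weight_quant) -> bool:
        # MXFP4 per the OCP MX spec: 4-bit float weights in groups of 32.
        return (
            weight_quant is not None
            and weight_quant.num_bits == 4
            and weight_quant.type == "float"
            and weight_quant.group_size == 32
        )
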
@@ -224,6 +229,140 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
 )


+class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
+    def __init__(self, moe):
+        super().__init__(moe)
+        self.group_size = 32
+        self.mxfp4_backend = NvFp4MoeBackend.MARLIN
+        self.kernel: mk.FusedMoEModularKernel | None = None
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        layer.num_experts = num_experts
+        layer.params_dtype = params_dtype
+
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                # 2 fp4 items are packed in the input dimension
+                hidden_size // 2,
+                requires_grad=False,
+                dtype=torch.uint8,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_packed", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                # 2 fp4 items are packed in the input dimension
+                intermediate_size_per_partition // 2,
+                dtype=torch.uint8,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_packed", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        w13_weight_scale = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                # one scale per group of 32 input elements
+                hidden_size // self.group_size,
+                dtype=torch.uint8,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
+        )
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+
+        w2_weight_scale = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                # one scale per group of 32 input elements
+                intermediate_size_per_partition // self.group_size,
+                dtype=torch.uint8,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+    def get_fused_moe_quant_config(
+        self, layer: torch.nn.Module
+    ) -> FusedMoEQuantConfig | None:
+        return make_mxfp4_moe_quant_config(
+            w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale
+        )
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.w13_weight = torch.nn.Parameter(
+            layer.w13_weight_packed.data, requires_grad=False
+        )
+        delattr(layer, "w13_weight_packed")
+
+        layer.w2_weight = torch.nn.Parameter(
+            layer.w2_weight_packed.data, requires_grad=False
+        )
+        delattr(layer, "w2_weight_packed")
+
+        prepare_moe_fp4_layer_for_marlin(layer)
+
+        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
+        if self.moe_quant_config is not None:
+            self.kernel = make_nvfp4_moe_kernel(
+                backend=self.mxfp4_backend,
+                quant_config=self.moe_quant_config,
+                moe_config=self.moe,
+            )
+
+    def apply(
+        self,
+        layer: FusedMoE,
+        router: FusedMoERouter,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        if isinstance(x, tuple):
+            x_routing, _ = x
+        else:
+            x_routing = x
+
+        topk_weights, topk_ids = router.select_experts(
+            hidden_states=x_routing,
+            router_logits=router_logits,
+        )
+        assert self.kernel is not None
+        return self.kernel(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights,
+            topk_ids,
+            inplace=False,
+            activation=layer.activation,
+            global_num_experts=layer.global_num_experts,
+            expert_map=layer.expert_map,
+            apply_router_weight_on_input=layer.apply_router_weight_on_input,
+        )
+
+
 class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self,
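
Behind the "// 2" packing and the uint8 scales in create_weights() is the OCP MX layout: two FP4 (E2M1) values per byte along the input dimension, and one E8M0 power-of-two scale per group of 32 elements. A self-contained sketch of the dequantization this implies (not code from this PR; the low-nibble-first ordering is an assumption):

    import torch

    # The 16 E2M1 values, indexed by 4-bit code (sign bit high).
    E2M1_VALUES = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
         -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
    )

    def dequant_mxfp4(packed: torch.Tensor, scale_e8m0: torch.Tensor) -> torch.Tensor:
        """packed: uint8 [..., K // 2]; scale_e8m0: uint8 [..., K // 32] -> float [..., K]."""
        lo = packed & 0xF          # first FP4 code (assumed low nibble)
        hi = (packed >> 4) & 0xF   # second FP4 code
        codes = torch.stack([lo, hi], dim=-1).flatten(-2)    # [..., K]
        vals = E2M1_VALUES[codes.long()]
        scales = torch.pow(2.0, scale_e8m0.float() - 127.0)  # E8M0: 2 ** (byte - 127)
        return (vals.unflatten(-1, (-1, 32)) * scales.unsqueeze(-1)).flatten(-2)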
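
With this method registered, the test model from the new eval config can be exercised offline through vLLM's Python API; the flags mirror the config's server_args (a minimal sketch, untested here):

    from vllm import LLM

    llm = LLM(
        model="nm-testing/Qwen3-30B-A3B-MXFP4A16",
        enforce_eager=True,   # mirrors --enforce-eager
        max_model_len=4096,   # mirrors --max-model-len 4096
    )
    print(llm.generate(["What is 12 * 7?"])[0].outputs[0].text)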