From 827268e98d92761e9189c15baec6a452cf3ac945 Mon Sep 17 00:00:00 2001 From: PikaPikachu Date: Fri, 10 Apr 2026 01:24:43 +0800 Subject: [PATCH] [Quantization] Support Quark W8A8 INT8 MoE inference (#36320) Signed-off-by: kangletian --- tests/quantization/test_quark.py | 31 ++ vllm/model_executor/layers/fused_moe/utils.py | 11 +- .../layers/quantization/quark/quark.py | 38 +++ .../layers/quantization/quark/quark_moe.py | 282 ++++++++++++++++++ 4 files changed, 360 insertions(+), 2 deletions(-) diff --git a/tests/quantization/test_quark.py b/tests/quantization/test_quark.py index afb0437f5..c51b3461c 100644 --- a/tests/quantization/test_quark.py +++ b/tests/quantization/test_quark.py @@ -22,6 +22,9 @@ from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501 QuarkW8A8Fp8, QuarkW8A8Int8, ) +from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501 + QuarkW8A8Int8MoEMethod, +) from vllm.platforms import current_platform from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch @@ -126,6 +129,34 @@ def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp): assert output +@pytest.mark.parametrize("tp", [1]) +def test_quark_int8_w8a8_moe(vllm_runner, tp): + """Test W8A8 INT8 MoE quantization with a tiny Qwen3 MoE model.""" + model_path = "nameistoken/tiny-qwen3-moe-w8a8-int8-quark" + with vllm_runner( + model_path, + enforce_eager=True, + tensor_parallel_size=tp, + gpu_memory_utilization=0.1, + ) as llm: + + def check_model(model): + layer = model.model.layers[0] + # MoE experts should use QuarkW8A8Int8MoEMethod + moe = layer.mlp.experts + assert isinstance(moe.quant_method, QuarkW8A8Int8MoEMethod), ( + f"Expected QuarkW8A8Int8MoEMethod, got {type(moe.quant_method)}" + ) + # Non-MoE linear layers should use QuarkW8A8Int8 + qkv_proj = layer.self_attn.qkv_proj + assert isinstance(qkv_proj.scheme, QuarkW8A8Int8) + + llm.apply_model(check_model) + + output = llm.generate_greedy("Hello", max_tokens=4) + assert output + + def test_quark_fp8_parity(vllm_runner): quark_model_id = "amd-quark/llama-tiny-fp8-quark-quant-method" fp8_model_id = "amd-quark/llama-tiny-fp8-quant-method" diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index c576b0a25..ce1e49bc4 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -163,8 +163,15 @@ def _int8_quantize( # activations apply per-token quantization. Otherwise, assume # activation tensor-wise fp8/int8 quantization, dynamic or static if block_shape is None: - assert per_act_token, "int8 quantization only supports block or channel-wise" - A, A_scale = per_token_quant_int8(A) + if per_act_token: + A, A_scale = per_token_quant_int8(A) + elif A_scale is not None: + # Static per-tensor: use the optimized CUDA kernel + A, A_scale, _ = ops.scaled_int8_quant(A, scale=A_scale) + elif A_scale is None: + # Dynamic per-tensor: compute scale then quantize via kernel + A_scale = torch.clamp(A.abs().max() / 127.0, min=1e-10) + A, A_scale, _ = ops.scaled_int8_quant(A, scale=A_scale) else: assert not per_act_token assert len(block_shape) == 2 diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index d0362cedc..33bd0cfc2 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -389,6 +389,37 @@ class QuarkConfig(QuantizationConfig): return is_weight_mxfp4 and is_input_fp8 + def _is_dynamic_per_token_w8a8( + self, + weight_quant: dict[str, Any] | None, + input_quant: dict[str, Any] | None, + ) -> bool: + """Detect W8A8 INT8 with per-tensor or per-channel + weights and dynamic per-token input.""" + if weight_quant is None or input_quant is None: + return False + + is_int8_dtype = ( + weight_quant.get("dtype") == "int8" and input_quant.get("dtype") == "int8" + ) + + is_valid_weight_scheme = weight_quant.get("qscheme") in [ + "per_tensor", + "per_channel", + ] + is_per_token_input = input_quant.get("qscheme") == "per_channel" + + is_dynamic_input = input_quant.get("is_dynamic") is True + is_weight_symmetric = weight_quant.get("symmetric") is True + + return ( + is_int8_dtype + and is_valid_weight_scheme + and is_per_token_input + and is_dynamic_input + and is_weight_symmetric + ) + def _is_w_ocp_mx_a_x( self, weight_quant: dict[str, Any] | None, input_quant: dict[str, Any] | None ) -> bool: @@ -556,6 +587,13 @@ class QuarkConfig(QuantizationConfig): ) if is_w4a8_supported: return QuarkW4A8_MXFP4_FP8(weight_config, input_config) + elif self._is_dynamic_per_token_w8a8(weight_config, input_config): + weight_qscheme = cast(str, weight_config.get("qscheme")) + return QuarkW8A8Int8( + qscheme=weight_qscheme, + is_static_input_scheme=False, + input_symmetric=input_config.get("symmetric"), + ) elif self._is_w_ocp_mx_a_x(weight_config, input_config): return QuarkOCP_MX( weight_config, input_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 3f7ddbfd7..58ed8940b 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -109,6 +109,12 @@ class QuarkMoEMethod(FusedMoEMethodBase): return QuarkOCP_MX_MoEMethod( weight_config, input_config, module.moe_config ) + elif quant_config._is_static_tensor_w8a8( + weight_config, input_config + ) or quant_config._is_dynamic_per_token_w8a8(weight_config, input_config): + return QuarkW8A8Int8MoEMethod( + weight_config, input_config, module.moe_config + ) else: raise RuntimeError("Unsupported FusedMoe scheme") @@ -505,6 +511,282 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ) +class QuarkW8A8Int8MoEMethod(QuarkMoEMethod): + """Quark W8A8 INT8 MoE method.""" + + def __init__( + self, + weight_config: dict[str, Any], + input_config: dict[str, Any], + moe: FusedMoEConfig, + ): + super().__init__(moe) + self.weight_quant = weight_config + self.input_quant = input_config + self.weight_qscheme = self.weight_quant.get("qscheme", "per_tensor") + self.static_input_scales = not self.input_quant.get("is_dynamic", False) + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + layer.num_experts = num_experts + layer.orig_dtype = params_dtype + layer.weight_block_size = None + params_dtype = torch.int8 + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + if self.weight_qscheme == "per_channel": + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + else: + # per-tensor: one scalar per expert + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.static_input_scales: + w13_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + # ZERO POINTS (loaded but discarded after loading; kernel uses symmetric) + w13_input_zero_point = torch.nn.Parameter( + torch.zeros(num_experts, 2, dtype=torch.int8), + requires_grad=False, + ) + layer.register_parameter("w13_input_zero_point", w13_input_zero_point) + set_weight_attrs(w13_input_zero_point, extra_weight_attrs) + + w2_input_zero_point = torch.nn.Parameter( + torch.zeros(num_experts, dtype=torch.int8), + requires_grad=False, + ) + layer.register_parameter("w2_input_zero_point", w2_input_zero_point) + set_weight_attrs(w2_input_zero_point, extra_weight_attrs) + + if self.weight_qscheme == "per_channel": + w13_weight_zero_point = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + dtype=torch.int8, + ), + requires_grad=False, + ) + w2_weight_zero_point = torch.nn.Parameter( + torch.zeros(num_experts, hidden_size, dtype=torch.int8), + requires_grad=False, + ) + else: + w13_weight_zero_point = torch.nn.Parameter( + torch.zeros(num_experts, 2, dtype=torch.int8), + requires_grad=False, + ) + w2_weight_zero_point = torch.nn.Parameter( + torch.zeros(num_experts, dtype=torch.int8), + requires_grad=False, + ) + layer.register_parameter("w13_weight_zero_point", w13_weight_zero_point) + set_weight_attrs(w13_weight_zero_point, extra_weight_attrs) + layer.register_parameter("w2_weight_zero_point", w2_weight_zero_point) + set_weight_attrs(w2_weight_zero_point, extra_weight_attrs) + + # BIAS + if self.has_bias: + w13_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + w2_bias = torch.nn.Parameter( + torch.zeros(num_experts, hidden_size, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + else: + layer.w13_bias, layer.w2_bias = None, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Discard zero points (INT8 fused MoE kernel uses symmetric quant) + for attr in ( + "w13_input_zero_point", + "w2_input_zero_point", + "w13_weight_zero_point", + "w2_weight_zero_point", + ): + if hasattr(layer, attr): + delattr(layer, attr) + + # For static input scales, collapse per-expert scales to single max + if self.static_input_scales: + if layer.w13_input_scale is None or layer.w2_input_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None." + ) + if not all_close_1d(layer.w13_input_scale) or not all_close_1d( + layer.w2_input_scale + ): + logger.warning_once( + "Found input_scales that are not equal for " + "INT8 MoE layer. Using the maximum across experts " + "for each layer." + ) + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False + ) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False + ) + + # For per-tensor weights, merge w1/w3 scales into single per-expert + if self.weight_qscheme == "per_tensor": + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + + for expert_id in range(layer.local_num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start : start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id], + ) + layer.w13_weight[expert_id][start : start + shard_size, :], _, _ = ( + ops.scaled_int8_quant( + dq_weight, + scale=max_w13_scales[expert_id], + ) + ) + start += shard_size + + layer.w13_weight_scale = torch.nn.Parameter( + max_w13_scales, requires_grad=False + ) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + is_dynamic = not self.static_input_scales + is_per_channel = self.weight_qscheme == "per_channel" + return FusedMoEQuantConfig.make( + torch.int8, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + w1_bias=getattr(layer, "w13_bias", None), + w2_bias=getattr(layer, "w2_bias", None), + per_act_token_quant=is_dynamic, + per_out_ch_quant=is_per_channel, + block_shape=None, + ) + + def apply( + self, + layer: FusedMoE, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + from vllm.model_executor.layers.fused_moe import fused_experts + + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=not self.moe.disable_inplace, + activation=layer.activation, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + quant_config=self.moe_quant_config, + ) + + class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod): def __init__( self,