diff --git a/tests/quantization/test_quark.py b/tests/quantization/test_quark.py
index 4a0c8ba4d..c09931971 100644
--- a/tests/quantization/test_quark.py
+++ b/tests/quantization/test_quark.py
@@ -77,6 +77,31 @@ def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp):
         assert output
 
 
+@pytest.mark.parametrize('tp', [1])
+def test_quark_fp8_w_per_channel_a_per_token(vllm_runner, tp):
+    model_path = "amd/Qwen2.5-1.5B-Instruct-ptpc-Quark-ts"
+    with vllm_runner(model_path, tensor_parallel_size=tp) as llm:
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+
+            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
+            assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)
+
+            if isinstance(qkv_proj.scheme, QuarkW8A8Fp8):
+                assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
+                assert qkv_proj.weight_scale.shape[0] == qkv_proj.weight.shape[
+                    1]
+                assert qkv_proj.weight_scale.shape[1] == 1
+
+        llm.apply_model(check_model)
+
+        output = llm.generate_greedy("Hello my name is", max_tokens=20)
+        assert output
+
+
 @pytest.mark.parametrize('tp', [1])
 def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
     model_path = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 6cff9f301..bc8ae9804 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -5,17 +5,25 @@ from typing import Any, Callable, Optional, Union
 
 import torch
 
+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
                                                   FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    is_rocm_aiter_moe_enabled)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    prepare_moe_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     OCP_MX_BLOCK_SIZE)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
+from vllm.scalar_type import scalar_types
 
 logger = init_logger(__name__)
 
@@ -67,21 +75,45 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
         self.weight_quant = weight_config
         self.input_quant = input_config
 
-        weight_qscheme = self.weight_quant.get("qscheme")
-        input_qscheme = self.input_quant.get("qscheme")
-        if not (weight_qscheme == "per_tensor"
-                and input_qscheme == "per_tensor"):
+        self.weight_qscheme = self.weight_quant.get("qscheme")
+        self.input_qscheme = self.input_quant.get("qscheme")
+        per_tensor = (self.weight_qscheme == "per_tensor"
+                      and self.input_qscheme == "per_tensor")
+        per_channel = (self.weight_qscheme == "per_channel"
+                       and self.input_qscheme == "per_channel")
+        self.act_quant_group_shape = GroupShape.PER_TOKEN \
+            if per_channel else GroupShape.PER_TENSOR
+        if not (per_tensor or per_channel):
             raise ValueError(
-                "For FP8 Fused MoE layers, only per-tensor scales "
-                "for weights and activations are supported. Found "
-                f"{weight_qscheme}, {input_qscheme}")  # noqa E501
+                "For FP8 Fused MoE layers, only per-tensor and per-channel "
+                "scales for weights and activations are supported. Found "
+                f"{self.weight_qscheme}, {self.input_qscheme}")  # noqa E501
 
         self.static_input_scales = not self.input_quant.get("is_dynamic")
+        if self.static_input_scales and per_channel:
+            raise ValueError(
+                "For FP8 Fused MoE layer, we require either per tensor or "
+                "channelwise, dynamic per token quantization.")
+
+        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
+        # kernel for fast weight-only FP8 quantization
+        self.use_marlin = (not current_platform.has_device_capability(89)
+                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
+        # Disable marlin for rocm
+        if current_platform.is_rocm():
+            self.use_marlin = False
+
+        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
+        layer.intermediate_size_per_partition = intermediate_size_per_partition
+        layer.hidden_size = hidden_size
+        layer.num_experts = num_experts
+        layer.orig_dtype = params_dtype
+        layer.weight_block_size = None
         params_dtype = torch.float8_e4m3fn
 
         # WEIGHTS
@@ -104,24 +136,39 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                         2,
-                                                         dtype=torch.float32),
-                                              requires_grad=False)
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-
-        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                        dtype=torch.float32),
-                                             requires_grad=False)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
-        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        if self.weight_qscheme == "per_tensor":
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They are combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, 2, dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-TENSOR quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        elif self.weight_qscheme == "per_channel":
+            # quark's scale is 1 dim.
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, hidden_size, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
         # INPUT_SCALES
         if self.static_input_scales:
@@ -185,24 +232,60 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
             layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
                                                       requires_grad=False)
 
-        # Fp8 moe kernel needs single weight scale for w13 per expert.
-        # We take the max then dequant and requant each expert.
-        assert layer.w13_weight_scale is not None
-        shard_size = layer.intermediate_size_per_partition
-        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.local_num_experts):
-            start = 0
-            for shard_id in range(2):
-                dq_weight = per_tensor_dequantize(
-                    layer.w13_weight[expert_id][start:start + shard_size, :],
-                    layer.w13_weight_scale[expert_id][shard_id])
-                layer.w13_weight[expert_id][
-                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
-                        dq_weight, max_w13_scales[expert_id])
-                start += shard_size
+        # For per-tensor case, Fp8 moe kernel needs single weight scale
+        # for w13 per expert. Use max then dequant and requant each expert.
+        if self.weight_qscheme == "per_tensor":
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id])
+                    layer.w13_weight[expert_id][
+                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
+                            dq_weight, max_w13_scales[expert_id])
+                    start += shard_size
 
-        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
-                                                    requires_grad=False)
+            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
+                                                        requires_grad=False)
+        # quark's scale is 1 dim.
+        elif self.weight_qscheme == "per_channel":
+            if self.act_quant_group_shape == GroupShape.PER_TOKEN:
+                w13_weight_scale = layer.w13_weight_scale.unsqueeze(-1)
+                layer.w13_weight_scale = torch.nn.Parameter(
+                    w13_weight_scale, requires_grad=False)
+                w2_weight_scale = layer.w2_weight_scale.unsqueeze(-1)
+                layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
+                                                           requires_grad=False)
+        # Property to determine if AITER is used
+        if self.rocm_aiter_moe_enabled:
+            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa E501
+                rocm_aiter_fused_experts, shuffle_weights)
+
+            # reshaping weights is required for aiter moe kernel.
+            shuffled_w13, shuffled_w2 = shuffle_weights(
+                layer.w13_weight.data, layer.w2_weight.data)
+
+            layer.w13_weight = torch.nn.Parameter(shuffled_w13,
+                                                  requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(shuffled_w2,
+                                                 requires_grad=False)
+
+            self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
+        elif self.use_marlin:
+
+            prepare_moe_fp8_layer_for_marlin(layer, False)
+            # Activations not quantized for marlin.
+            del layer.w13_input_scale
+            del layer.w2_input_scale
+            self.fused_experts_func = None
+        else:
+            from vllm.model_executor.layers.fused_moe import fused_experts
+            self.fused_experts_func = fused_experts
 
     def apply(
         self,
@@ -233,8 +316,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
             raise NotImplementedError(
                 "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.")
 
-        from vllm.model_executor.layers.fused_moe import fused_experts
-
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -249,22 +330,60 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype)
 
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
+        if self.rocm_aiter_moe_enabled:
+            return self.rocm_aiter_fused_experts_func(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                activation=activation,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                use_fp8_w8a8=True,
+                per_channel_quant=self.weight_qscheme == "per_channel",
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                expert_map=expert_map)
+        if self.use_marlin:
+            assert activation == "silu", (
+                f"{activation} not supported for Marlin MoE.")
+            return torch.ops.vllm.fused_marlin_moe(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                None,
+                None,
+                layer.w13_weight_scale,
+                layer.w2_weight_scale,
+                router_logits,
+                topk_weights,
+                topk_ids,
+                quant_type_id=scalar_types.float8_e4m3fn.id,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                global_num_experts=global_num_experts,
+                expert_map=expert_map)
+
+        assert self.fused_experts_func is not None
+
+        return self.fused_experts_func(
+            hidden_states=x,
+            w1=layer.w13_weight,
+            w2=layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=True,
-            use_fp8_w8a8=True,
-            global_num_experts=global_num_experts,
+            activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
+            use_fp8_w8a8=True,
+            per_channel_quant=self.weight_qscheme == "per_channel",
+            global_num_experts=global_num_experts,
             expert_map=expert_map,
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-            activation=activation)
+            a2_scale=layer.w2_input_scale)
 
 
 class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):