diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 03d25a9b1..7cf7b76d6 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -57,8 +57,8 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes. - [`ModelOptFp8MoEMethod`][vllm.model_executor.layers.quantization.modelopt.ModelOptFp8MoEMethod] - [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod] -- [`CompressedTensorsW4A4Nvfp4MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4Nvfp4MoEMethod] -- [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Fp8MoEMethod] +- [`CompressedTensorsW4A4Nvfp4MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.compressed_tensors_moe_w4a4_nvfp4.CompressedTensorsW4A4Nvfp4MoEMethod] +- [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.compressed_tensors_moe_w8a8_fp8.CompressedTensorsW8A8Fp8MoEMethod] - [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod] - [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py deleted file mode 100644 index bce63bcbe..000000000 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ /dev/null @@ -1,2541 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import enum -from enum import Enum - -import torch -from compressed_tensors import CompressionFormat -from compressed_tensors.quantization import ( - ActivationOrdering, - QuantizationArgs, - QuantizationStrategy, -) - -import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm import _custom_ops as ops -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import ( - FusedMoE, - FusedMoEActivationFormat, - FusedMoEExpertsModular, - FusedMoEMethodBase, - FusedMoeWeightScaleSupported, - UnquantizedFusedMoEMethod, -) -from vllm.model_executor.layers.fused_moe.activation import MoEActivation -from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEConfig, - FusedMoEQuantConfig, - int4_w4a16_moe_quant_config, - int4_w4afp8_moe_quant_config, - int8_w8a8_moe_quant_config, - int8_w8a16_moe_quant_config, -) -from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts -from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( - BatchedMarlinExperts, - MarlinExperts, - fused_marlin_moe, -) -from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( - convert_to_fp8_moe_kernel_format, - make_fp8_moe_kernel, - make_fp8_moe_quant_config, - select_fp8_moe_backend, -) -from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( - Mxfp4MoeBackend, - make_mxfp4_moe_kernel, - make_mxfp4_moe_quant_config, -) -from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( - convert_to_nvfp4_moe_kernel_format, - is_global_sf_supported_for_nvfp4_backend, - make_nvfp4_moe_kernel, - make_nvfp4_moe_quant_config, - select_nvfp4_moe_backend, -) -from 
vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa - WNA16_SUPPORTED_BITS, - WNA16_SUPPORTED_TYPES_MAP, -) -from vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe import ( - flashinfer_trtllm_mxint4_moe, - is_flashinfer_mxint4_moe_available, - prepare_static_weights_for_trtllm_mxint4_moe, -) -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - process_fp8_input_tensor_strategy_moe, - process_fp8_weight_tensor_strategy_moe, -) -from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - check_moe_marlin_supports_layer, - get_marlin_input_dtype, - marlin_act_int8_process_scales, - marlin_make_workspace_new, - marlin_moe_permute_scales, -) -from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - prepare_moe_fp4_layer_for_marlin, -) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - convert_bf16_scales_to_fp8, - convert_packed_uint4b8_to_signed_int4_inplace, - kFp8Dynamic128Sym, - kFp8DynamicTokenSym, - kFp8Static128BlockSym, - kFp8StaticChannelSym, - kFp8StaticTensorSym, - kNvfp4Dynamic, - kNvfp4Static, -) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - normalize_e4m3fn_to_e4m3fnuz, -) -from vllm.model_executor.utils import replace_parameter, set_weight_attrs -from vllm.platforms import CpuArchEnum, current_platform - -logger = init_logger(__name__) - - -class GPTQMarlinState(Enum): - REPACK = enum.auto() - READY = enum.auto() - - -__all__ = [ - "CompressedTensorsMoEMethod", - "CompressedTensorsW8A8Fp8MoEMethod", - "CompressedTensorsW8A8Int8MoEMethod", - "CompressedTensorsWNA16MarlinMoEMethod", - "CompressedTensorsWNA16MoEMethod", - "CompressedTensorsW4A4Nvfp4MoEMethod", - "CompressedTensorsW4A8Int8MoEMethod", -] - - -class CompressedTensorsMoEMethod(FusedMoEMethodBase): - @staticmethod - def get_moe_method( - quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 - layer: torch.nn.Module, - layer_name: str, - ) -> FusedMoEMethodBase: - # FusedMoE was made by combining multiple Linears so need to - # make sure quantization config for Linear can target it - quant_config._add_fused_moe_to_target_scheme_map() - unfused_names = [ - layer_name + proj_name - for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"] - ] - # TODO: refactor this to use expert_mapping and check all layer numbers - all_scheme_dicts = [ - quant_config.get_scheme_dict(layer, name) for name in unfused_names - ] - scheme_dict = all_scheme_dicts.pop() - - # multiple schemes found - if not all([cur_dict == scheme_dict for cur_dict in all_scheme_dicts]): - raise ValueError( - "All MoE projections need to have same " - "quantization scheme but found multiple" - ) - - if scheme_dict is None: # ignored layer - return UnquantizedFusedMoEMethod(layer.moe_config) - - # TODO: @dsikka: refactor this to use schemes as other kernels - # are supported + check if the layer is being ignored. 
- weight_quant = scheme_dict.get("weights") - input_quant = scheme_dict.get("input_activations") - format = scheme_dict.get("format") - - if quant_config._is_mxfp4(weight_quant): - return CompressedTensorsW4A4Mxfp4MoEMethod(layer.moe_config) - - if quant_config._is_wNa16_group_channel(weight_quant, input_quant): - # group_size=None means channelwise - group_size = weight_quant.group_size or -1 - - valid_format_and_bits = ( - weight_quant.num_bits in WNA16_SUPPORTED_BITS - and format == CompressionFormat.pack_quantized.value - ) - - if not valid_format_and_bits: - raise ValueError( - "For Fused MoE layers, only format: ", - f"{CompressionFormat.pack_quantized.value} ", - f" and bits: {WNA16_SUPPORTED_BITS} is supported ", - f"but got format: {CompressionFormat.pack_quantized.value} " - f" and bits: {weight_quant.num_bits}", - ) - - # Prefer to use the MarlinMoE kernel when it is supported. - if ( - not check_moe_marlin_supports_layer(layer, group_size) - or current_platform.is_rocm() - ): - if ( - weight_quant.strategy == QuantizationStrategy.GROUP - and weight_quant.actorder - in (ActivationOrdering.GROUP, ActivationOrdering.DYNAMIC) - ): - raise ValueError( - "WNA16MoE is not supported with actorder=group/dynamic." - ) - logger.info_once("Using CompressedTensorsWNA16MoEMethod") - return CompressedTensorsWNA16MoEMethod( - weight_quant, input_quant, layer.moe_config - ) - else: - logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") - return CompressedTensorsWNA16MarlinMoEMethod( - weight_quant, input_quant, layer.moe_config - ) - elif quant_config._is_nvfp4_format(weight_quant): - _is_valid_nvfp4_activations = ( - quant_config._is_nvfp4_format(input_quant) or input_quant is None - ) - if not _is_valid_nvfp4_activations: - raise ValueError( - "For NVFP4 weights, input quantization must also be NVFP4 format ", - f"or None for NVFP4A16, found {input_quant}", - ) - return CompressedTensorsW4A4Nvfp4MoEMethod( - layer.moe_config, layer_name, use_a16=(input_quant is None) - ) - elif ( - quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) - or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant) - or quant_config._is_fp8_w8a8(weight_quant, input_quant) - ): - return CompressedTensorsW8A8Fp8MoEMethod( - weight_quant, input_quant, layer.moe_config - ) - elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): - return CompressedTensorsW8A8Int8MoEMethod( - weight_quant, input_quant, layer.moe_config - ) - elif quant_config._is_fp8_w4a8_sm90(weight_quant, input_quant): - logger.info_once("Using CompressedTensorsW4A8Fp8MoEMethod") - return CompressedTensorsW4A8Fp8MoEMethod( - weight_quant, input_quant, layer.moe_config - ) - elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant): - return CompressedTensorsW4A8Int8MoEMethod( - weight_quant, input_quant, layer.moe_config - ) - else: - raise RuntimeError( - f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}" - ) - - -class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): - def __init__(self, moe): - super().__init__(moe) - self.group_size = 32 - self.mxfp4_backend = Mxfp4MoeBackend.MARLIN - self.experts_cls = MarlinExperts - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - layer.num_experts = num_experts - layer.params_dtype = params_dtype - - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * 
intermediate_size_per_partition, - # 2 fp4 items are packed in the input dimension - hidden_size // 2, - requires_grad=False, - dtype=torch.uint8, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_packed", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - # 2 fp4 items are packed in the input dimension - intermediate_size_per_partition // 2, - dtype=torch.uint8, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_packed", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - w13_weight_scale = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - # 2 fp4 items are packed in the input dimension - hidden_size // self.group_size, - dtype=torch.uint8, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} - ) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - - w2_weight_scale = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - # 2 fp4 items are packed in the input dimension - intermediate_size_per_partition // self.group_size, - dtype=torch.uint8, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: - return make_mxfp4_moe_quant_config( - mxfp4_backend=self.mxfp4_backend, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - ) - - def process_weights_after_loading(self, layer: FusedMoE) -> None: - layer.w13_weight = torch.nn.Parameter( - layer.w13_weight_packed.data, requires_grad=False - ) - delattr(layer, "w13_weight_packed") - - layer.w2_weight = torch.nn.Parameter( - layer.w2_weight_packed.data, requires_grad=False - ) - delattr(layer, "w2_weight_packed") - - logger.warning_once( - "Your GPU does not have native support for FP4 computation but " - "FP4 quantization is being used. Weight-only FP4 compression " - "will be used leveraging the Marlin kernel. This may degrade " - "performance for compute-heavy workloads." 
- ) - prepare_moe_fp4_layer_for_marlin(layer) - - self.moe_quant_config = self.get_fused_moe_quant_config(layer) - if self.moe_quant_config is not None: - self.moe_kernel = make_mxfp4_moe_kernel( - moe_quant_config=self.moe_quant_config, - moe_config=self.moe, - experts_cls=self.experts_cls, - mxfp4_backend=self.mxfp4_backend, - shared_experts=layer.shared_experts, - routing_tables=layer._maybe_init_expert_routing_tables(), - ) - - def apply( - self, - layer: FusedMoE, - x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - shared_experts_input: torch.Tensor | None, - ) -> torch.Tensor: - assert self.moe_kernel is not None - return self.moe_kernel.apply( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - shared_experts_input=shared_experts_input, - ) - - -class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): - def __init__( - self, - moe: FusedMoEConfig, - layer_name: str | None = None, - use_a16: bool = False, - ): - super().__init__(moe) - self.group_size = 16 - - # Select experts implementation. - self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend( - config=self.moe, - weight_key=kNvfp4Static, - activation_key=None if use_a16 else kNvfp4Dynamic, - ) - - self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( - self.nvfp4_backend - ) - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - layer.num_experts = num_experts - layer.params_dtype = params_dtype - w13_num_shards = 2 if self.moe.is_act_and_mul else 1 - - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, - w13_num_shards * intermediate_size_per_partition, - # 2 fp4 items are packed in the input dimension - hidden_size // 2, - requires_grad=False, - dtype=torch.uint8, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_packed", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - # 2 fp4 items are packed in the input dimension - intermediate_size_per_partition // 2, - dtype=torch.uint8, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_packed", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # Weight Scales - w13_weight_scale = torch.nn.Parameter( - torch.empty( - num_experts, - w13_num_shards * intermediate_size_per_partition, - # 2 fp4 items are packed in the input dimension - hidden_size // self.group_size, - dtype=torch.float8_e4m3fn, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} - ) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - - w2_weight_scale = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - # 2 fp4 items are packed in the input dimension - intermediate_size_per_partition // self.group_size, - dtype=torch.float8_e4m3fn, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} - ) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - - # Weight Global Scales - 
w13_weight_scale_2 = torch.nn.Parameter( - torch.empty(num_experts, w13_num_shards, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2) - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} - ) - set_weight_attrs(w13_weight_scale_2, extra_weight_attrs) - - w2_weight_scale_2 = torch.nn.Parameter( - torch.empty(num_experts, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2) - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} - ) - set_weight_attrs(w2_weight_scale_2, extra_weight_attrs) - - # Input Global Scales - w13_input_scale = torch.nn.Parameter( - torch.empty(num_experts, w13_num_shards, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w13_input_global_scale", w13_input_scale) - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} - ) - set_weight_attrs(w13_input_scale, extra_weight_attrs) - - w2_input_scale = torch.nn.Parameter( - torch.empty(num_experts, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w2_input_global_scale", w2_input_scale) - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} - ) - set_weight_attrs(w2_input_scale, extra_weight_attrs) - - def process_weights_after_loading(self, layer: FusedMoE) -> None: - """ - Convert NVFP4 MoE weights into kernel format and setup the kernel. - """ - # NOTE(rob): wN_weight_packed -> wN_weight is because ModularKernelMethod - # requires this naming convention. However, the name change breaks - # reloading because the state dict no longer matches disk. Once we - # remove MKM, we should revert this change to ensure compatibility. - layer.w13_weight = torch.nn.Parameter( - layer.w13_weight_packed.data, requires_grad=False - ) - delattr(layer, "w13_weight_packed") - - layer.w2_weight = torch.nn.Parameter( - layer.w2_weight_packed.data, requires_grad=False - ) - delattr(layer, "w2_weight_packed") - - # Use a single gscale for w13. - if self.moe.is_act_and_mul and not torch.allclose( - layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1] - ): - logger.warning_once( - "w1_weight_global_scale must match w3_weight_global_scale. " - "Accuracy may be affected.", - ) - w13_weight_global_scale = layer.w13_weight_global_scale[:, 0].contiguous() - - # Shuffle weights into the NvFp4 kernel format. - ( - w13, - w13_scale, - w13_scale_2, - a13_scale, - w2, - w2_scale, - w2_scale_2, - a2_scale, - ) = convert_to_nvfp4_moe_kernel_format( - nvfp4_backend=self.nvfp4_backend, - layer=layer, - w13=layer.w13_weight, - w13_scale=layer.w13_weight_scale, - w13_scale_2=(1.0 / w13_weight_global_scale), - a13_scale=(1.0 / layer.w13_input_global_scale), - w2=layer.w2_weight, - w2_scale=layer.w2_weight_scale, - w2_scale_2=(1.0 / layer.w2_weight_global_scale), - a2_scale=(1.0 / layer.w2_input_global_scale), - is_act_and_mul=self.moe.is_act_and_mul, - ) - - replace_parameter(layer, "w13_weight", w13) - replace_parameter(layer, "w13_weight_scale", w13_scale) - replace_parameter(layer, "w2_weight", w2) - replace_parameter(layer, "w2_weight_scale", w2_scale) - layer.w13_weight_scale_2 = w13_scale_2 - layer.w2_weight_scale_2 = w2_scale_2 - layer.w13_input_scale = a13_scale - layer.w2_input_scale = a2_scale - - # Setup modular kernel. 
- self.moe_quant_config = self.get_fused_moe_quant_config(layer) - assert self.experts_cls is not None - self.moe_kernel = make_nvfp4_moe_kernel( - moe_quant_config=self.moe_quant_config, - moe_config=self.moe, - experts_cls=self.experts_cls, - shared_experts=layer.shared_experts, - routing_tables=layer._maybe_init_expert_routing_tables(), - ) - self.moe_kernel.fused_experts.process_weights_after_loading(layer) - - def maybe_make_prepare_finalize( - self, - routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - ) -> mk.FusedMoEPrepareAndFinalizeModular | None: - raise ValueError( - f"{self.__class__.__name__} uses the new modular kernel initialization " - "logic. This function should not be called." - ) - - def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: - return make_nvfp4_moe_quant_config( - backend=self.nvfp4_backend, - w13_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - w13_scale_2=layer.w13_weight_scale_2, - w2_scale_2=layer.w2_weight_scale_2, - a13_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - ) - - def apply_monolithic( - self, - layer: FusedMoE, - x: torch.Tensor, - router_logits: torch.Tensor, - ) -> torch.Tensor: - assert self.is_monolithic - assert self.moe_kernel is not None - return self.moe_kernel.apply_monolithic( - x, - layer.w13_weight, - layer.w2_weight, - router_logits, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - num_expert_group=layer.num_expert_group, - topk_group=layer.topk_group, - e_score_correction_bias=layer.e_score_correction_bias, - routed_scaling_factor=layer.routed_scaling_factor, - ) - - def apply( - self, - layer: FusedMoE, - x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - shared_experts_input: torch.Tensor | None, - ) -> torch.Tensor: - assert self.moe_kernel is not None - return self.moe_kernel.apply( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - shared_experts_input=shared_experts_input, - ) - - -class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): - """W8A8 FP8 MoE quantization using compressed tensors.""" - - def __init__( - self, - weight_quant: QuantizationArgs, - input_quant: QuantizationArgs, - moe: FusedMoEConfig, - layer_name: str | None = None, - ): - super().__init__(moe) - self.weight_quant = weight_quant - self.input_quant = input_quant - - per_tensor = ( - self.weight_quant.strategy == QuantizationStrategy.TENSOR - and self.input_quant.strategy == QuantizationStrategy.TENSOR - ) - per_channel = ( - self.weight_quant.strategy == QuantizationStrategy.CHANNEL - and self.input_quant.strategy == QuantizationStrategy.TOKEN - ) - if not (per_tensor or per_channel): - assert self.weight_quant.strategy == QuantizationStrategy.BLOCK - self.weight_block_size = self.weight_quant.block_structure - assert self.weight_quant.dynamic is not None - else: - self.weight_block_size = None - self.block_quant = self.weight_block_size is not None - - self.static_input_scales = not self.input_quant.dynamic - if self.static_input_scales and per_channel: - raise ValueError( - "For FP8 Fused MoE layer, we require either per tensor or " - "channelwise, dynamic per token quantization." 
- ) - - ct2vllm_weight = { - QuantizationStrategy.CHANNEL: kFp8StaticChannelSym, - QuantizationStrategy.TENSOR: kFp8StaticTensorSym, - QuantizationStrategy.BLOCK: kFp8Static128BlockSym, - } - ct2vllm_act = { - QuantizationStrategy.TOKEN: kFp8DynamicTokenSym, - QuantizationStrategy.TENSOR: ( - kFp8StaticTensorSym if self.static_input_scales else kFp8Dynamic128Sym - ), - } - weight_key = ct2vllm_weight[self.weight_quant.strategy] - if weight_key == kFp8Static128BlockSym: - activation_key = kFp8Dynamic128Sym - else: - activation_key = ct2vllm_act[self.input_quant.strategy] - - # Select Fp8 MoE backend - self.fp8_backend, self.experts_cls = select_fp8_moe_backend( - config=self.moe, - weight_key=weight_key, - activation_key=activation_key, - allow_vllm_cutlass=True, - ) - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - layer.num_experts = num_experts - layer.orig_dtype = params_dtype - layer.weight_block_size = None - - params_dtype = torch.float8_e4m3fn - w13_num_shards = 2 if self.moe.is_act_and_mul else 1 - - if self.block_quant: - assert self.weight_block_size is not None - layer.weight_block_size = self.weight_block_size - tp_size = get_tensor_model_parallel_world_size() - block_n, block_k = ( - self.weight_block_size[0], - self.weight_block_size[1], - ) - # NOTE: To ensure proper alignment of the block-wise quantization - # scales, the output_size of the weights for both the gate and up - # layers must be divisible by block_n. - # Required by column parallel or enabling merged weights - if intermediate_size_per_partition % block_n != 0: - raise ValueError( - f"The output_size of gate's and up's weight = " - f"{intermediate_size_per_partition} is not divisible by " - f"weight quantization block_n = {block_n}." - ) - if tp_size > 1 and intermediate_size_per_partition % block_k != 0: - # Required by row parallel - raise ValueError( - f"The input_size of down's weight = " - f"{intermediate_size_per_partition} is not divisible by " - f"weight quantization block_k = {block_k}." - ) - - # WEIGHTS - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, - w13_num_shards * intermediate_size_per_partition, - hidden_size, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # WEIGHT_SCALES - if self.weight_quant.strategy == QuantizationStrategy.TENSOR: - # For gated MoE, allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. - # For non-gated MoE, allocate 1 scale for w13. - w13_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, w13_num_shards, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - w2_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - # Add PER-TENSOR quantization for FusedMoE.weight_loader. 
- extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} - ) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - - elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL: - w13_weight_scale = torch.nn.Parameter( - torch.ones( - num_experts, - w13_num_shards * intermediate_size_per_partition, - 1, - dtype=torch.float32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - w2_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - # Add PER-CHANNEL quantization for FusedMoE.weight_loader. - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} - ) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - - elif self.weight_quant.strategy == QuantizationStrategy.BLOCK: - w13_weight_scale = torch.nn.Parameter( - torch.ones( - num_experts, - w13_num_shards - * ((intermediate_size_per_partition + block_n - 1) // block_n), - (hidden_size + block_k - 1) // block_k, - dtype=torch.float32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - w2_weight_scale = torch.nn.Parameter( - torch.ones( - num_experts, - (hidden_size + block_n - 1) // block_n, - (intermediate_size_per_partition + block_k - 1) // block_k, - dtype=torch.float32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - # Add PER-CHANNEL quantization for FusedMoE.weight_loader. - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} - ) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - - # INPUT_SCALES - if self.static_input_scales: - w13_input_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w13_input_scale", w13_input_scale) - set_weight_attrs(w13_input_scale, extra_weight_attrs) - - w2_input_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w2_input_scale", w2_input_scale) - set_weight_attrs(w2_input_scale, extra_weight_attrs) - else: - layer.w13_input_scale = None - layer.w2_input_scale = None - - def process_weights_after_loading(self, layer: FusedMoE) -> None: - # Allow for accessing weights and scales in standard way. - w13 = layer.w13_weight - w2 = layer.w2_weight - w13_scale = layer.w13_weight_scale - w2_scale = layer.w2_weight_scale - w13_input_scale = layer.w13_input_scale - w2_input_scale = layer.w2_input_scale - - # MI300x and MI325x use FNUZ format for FP8. Convert if needed. - if current_platform.is_fp8_fnuz(): - w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz( - w13, w13_scale, w13_input_scale - ) - w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz( - w2, w2_scale, w2_input_scale - ) - - # Per tensor kernels require single activation scale. Use the max. 
- if self.static_input_scales: - assert self.input_quant.strategy == QuantizationStrategy.TENSOR - assert w13_input_scale is not None and w2_input_scale is not None - w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe( - w13_input_scale, w2_input_scale - ) - replace_parameter(layer, "w13_input_scale", w13_input_scale) - replace_parameter(layer, "w2_input_scale", w2_input_scale) - - # Per-tensor kernels use a single scale, for W13, but on disk there - # is a separate scale for W1 and W3. Requantize with the max scale. - if self.weight_quant.strategy == QuantizationStrategy.TENSOR: - w13, w13_scale = process_fp8_weight_tensor_strategy_moe( - w13, - w13_scale, - shard_size=layer.intermediate_size_per_partition, - num_experts=layer.local_num_experts, - is_act_and_mul=self.moe.is_act_and_mul, - ) - - w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format( - fp8_backend=self.fp8_backend, - layer=layer, - w13=w13, - w2=w2, - w13_scale=w13_scale, - w2_scale=w2_scale, - w13_input_scale=w13_input_scale, - w2_input_scale=w2_input_scale, - ) - - # Replace parameters with updated versions. Note that this helper - # function ensures the replacement is compatible with RL weight reloads. - replace_parameter(layer, "w13_weight", w13) - replace_parameter(layer, "w2_weight", w2) - replace_parameter(layer, "w13_weight_scale", w13_scale) - replace_parameter(layer, "w2_weight_scale", w2_scale) - - # Setup modular kernel for TP case and naive DP/EP case. - # In non-naive DP/EP case, we will create a ModularKernelMethod. - # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel - # in both cases. - self.moe_quant_config = self.get_fused_moe_quant_config(layer) - if self.moe_quant_config: - assert self.experts_cls is not None - self.moe_kernel = make_fp8_moe_kernel( - moe_quant_config=self.moe_quant_config, - moe_config=self.moe, - fp8_backend=self.fp8_backend, - experts_cls=self.experts_cls, - routing_tables=layer._maybe_init_expert_routing_tables(), - shared_experts=layer.shared_experts, - ) - - def maybe_make_prepare_finalize( - self, - routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - ) -> mk.FusedMoEPrepareAndFinalizeModular | None: - raise ValueError( - f"{self.__class__.__name__} uses the new modular kernel initialization " - "logic. This function should not be called." 
- ) - - def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: - is_per_token = self.input_quant.strategy == QuantizationStrategy.TOKEN - return make_fp8_moe_quant_config( - fp8_backend=self.fp8_backend, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - per_act_token_quant=is_per_token, - per_out_ch_quant=is_per_token, - block_shape=self.weight_block_size, - ) - - def apply_monolithic( - self, - layer: FusedMoE, - x: torch.Tensor, - router_logits: torch.Tensor, - ) -> torch.Tensor: - assert self.moe_kernel is not None - return self.moe_kernel.apply_monolithic( - x, - layer.w13_weight, - layer.w2_weight, - router_logits, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - num_expert_group=layer.num_expert_group, - topk_group=layer.topk_group, - e_score_correction_bias=layer.e_score_correction_bias, - routed_scaling_factor=layer.routed_scaling_factor, - ) - - def apply( - self, - layer: FusedMoE, - x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - shared_experts_input: torch.Tensor | None, - ) -> torch.Tensor: - assert not self.is_monolithic - assert self.moe_kernel is not None - return self.moe_kernel.apply( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - # TODO(rob): investigate the disable_expert_map introduced by: - # https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501 - expert_map=layer.expert_map, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - shared_experts_input=shared_experts_input, - ) - - @property - def supports_eplb(self) -> bool: - return True - - -class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): - def __init__( - self, - weight_quant: QuantizationArgs, - input_quant: QuantizationArgs, - moe: FusedMoEConfig, - layer_name: str | None = None, - ): - super().__init__(moe) - self.weight_quant = weight_quant - self.input_quant = input_quant - - per_channel = ( - self.weight_quant.strategy == QuantizationStrategy.CHANNEL - and self.input_quant.strategy == QuantizationStrategy.TOKEN - ) - if not per_channel: - raise ValueError( - "For INT8 Fused MoE layers, we require channelwise, " - "dynamic per token quantization. Found " - f"{self.weight_quant}, {self.input_quant}" - ) - - self.static_input_scales = not self.input_quant.dynamic - if self.static_input_scales: - raise ValueError( - "For INT8 Fused MoE layers, we require channelwise, " - "dynamic per token quantization. Found static input scales." 
- ) - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - params_dtype = torch.int8 - w13_num_shards = 2 if self.moe.is_act_and_mul else 1 - - # WEIGHTS - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, - w13_num_shards * intermediate_size_per_partition, - hidden_size, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # WEIGHT_SCALES - assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL - w13_weight_scale = torch.nn.Parameter( - torch.ones( - num_experts, - w13_num_shards * intermediate_size_per_partition, - 1, - dtype=torch.float32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - w2_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - # Add PER-CHANNEL quantization for FusedMoE.weight_loader. - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} - ) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - - # INPUT_SCALES - assert not self.static_input_scales - layer.w13_input_scale = None - layer.w2_input_scale = None - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - pass - - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: - return int8_w8a8_moe_quant_config( - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - per_act_token_quant=True, - ) - - def apply( - self, - layer: FusedMoE, - x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - shared_experts_input: torch.Tensor | None, - ) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe import fused_experts - - return fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=not self.moe.disable_inplace, - activation=layer.activation, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - quant_config=self.moe_quant_config, - ) - - -class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): - def __init__( - self, - weight_quant: QuantizationArgs, - input_quant: QuantizationArgs | None, - moe: FusedMoEConfig, - layer_name: str | None = None, - ): - super().__init__(moe) - self.weight_quant = weight_quant - self.input_quant = input_quant - assert weight_quant.symmetric, ( - "Only symmetric quantization is supported for MoE" - ) - # Extract properties from weight_quant - self.num_bits = weight_quant.num_bits - self.packed_factor = 32 // weight_quant.num_bits - self.strategy = weight_quant.strategy - self.group_size = weight_quant.group_size - self.actorder = weight_quant.actorder - - self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits] - 
- self.marlin_input_dtype = get_marlin_input_dtype(layer_name) - self.use_flashinfer_mxint4_moe = ( - is_flashinfer_mxint4_moe_available() - and self.group_size == 32 - and weight_quant.num_bits == 4 - ) - self.kernel_backend = ( - "Flashinfer" if self.use_flashinfer_mxint4_moe else "Marlin" - ) - logger.info_once( - f"Using {self.kernel_backend} backend for WNA16 MoE " - f"(group_size={self.group_size}, num_bits={self.num_bits})", - scope="local", - ) - - def get_weight_shape( - self, - weight_name: str, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - num_groups_w2: int | None = None, - num_groups_w13: int | None = None, - ) -> tuple[int, int, int]: - """ - Get the shape of the weight based on the weight name, number of experts - hidden size, intermediate size per partition, number of groups for w2, - and number of groups for w13. Pass in num_groups_w2 and num_groups_w13 - for weight scales. - """ - if weight_name == "w13_scale": - assert num_groups_w13 is not None, ( - "num_groups_w13 must be provided for weight scales" - ) - if weight_name == "w2_scale": - assert num_groups_w2 is not None, ( - "num_groups_w2 must be provided for weight scales" - ) - w13_num_shards = 2 if self.moe.is_act_and_mul else 1 - shape_map = { - "w13_weight": { - "Flashinfer": ( - num_experts, - w13_num_shards * intermediate_size_per_partition, - hidden_size // self.packed_factor, - ), - "Marlin": ( - num_experts, - hidden_size // self.packed_factor, - w13_num_shards * intermediate_size_per_partition, - ), - }, - "w13_scale": { - "Flashinfer": ( - num_experts, - w13_num_shards * intermediate_size_per_partition, - num_groups_w13, - ), - "Marlin": ( - num_experts, - num_groups_w13, - w13_num_shards * intermediate_size_per_partition, - ), - }, - "w2_weight": { - "Flashinfer": ( - num_experts, - hidden_size, - intermediate_size_per_partition // self.packed_factor, - ), - "Marlin": ( - num_experts, - intermediate_size_per_partition // self.packed_factor, - hidden_size, - ), - }, - "w2_scale": { - "Flashinfer": (num_experts, hidden_size, num_groups_w2), - "Marlin": (num_experts, num_groups_w2, hidden_size), - }, - } - return shape_map[weight_name][self.kernel_backend] - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full") - - # Will transpose the loaded weight along the - # intermediate and hidden dim sizes. 
Will - # shard for TP along the transposed dims - is_transposed = self.kernel_backend != "Flashinfer" - extra_weight_attrs.update( - {"is_transposed": is_transposed, "quant_method": self.strategy} - ) - - w13_weight = torch.nn.Parameter( - torch.empty( - *self.get_weight_shape( - "w13_weight", - num_experts, - hidden_size, - intermediate_size_per_partition, - ), - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_packed", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter( - torch.empty( - *self.get_weight_shape( - "w2_weight", - num_experts, - hidden_size, - intermediate_size_per_partition, - ), - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_packed", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # In the case where we have actorder/g_idx, - # we do not partition the w2 scales - load_full_w2 = self.actorder and self.group_size != -1 - w2_scales_size = ( - intermediate_size_full if load_full_w2 else intermediate_size_per_partition - ) - - self.is_k_full = (not self.actorder) or ( - intermediate_size_per_partition == intermediate_size_full - ) - - if self.strategy == "channel": - num_groups_w2 = num_groups_w13 = 1 - self.group_size = -1 - else: - num_groups_w2 = w2_scales_size // self.group_size - num_groups_w13 = hidden_size // self.group_size - - layer.num_groups_w13 = num_groups_w13 - layer.num_groups_w2 = num_groups_w2 - - w13_scale = torch.nn.Parameter( - torch.ones( - *self.get_weight_shape( - "w13_scale", - num_experts, - hidden_size, - intermediate_size_per_partition, - num_groups_w13=num_groups_w13, - ), - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_scale) - set_weight_attrs(w13_scale, extra_weight_attrs) - - w2_scale = torch.nn.Parameter( - torch.ones( - *self.get_weight_shape( - "w2_scale", - num_experts, - hidden_size, - intermediate_size_per_partition, - num_groups_w2=num_groups_w2, - ), - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_scale) - set_weight_attrs(w2_scale, extra_weight_attrs) - set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2}) - - w2_weight_shape = torch.nn.Parameter( - torch.empty(num_experts, 2), requires_grad=False - ) - layer.register_parameter("w2_weight_shape", w2_weight_shape) - set_weight_attrs(w2_weight_shape, extra_weight_attrs) - w13_weight_shape = torch.nn.Parameter( - torch.empty(num_experts, 2), requires_grad=False - ) - - layer.register_parameter("w13_weight_shape", w13_weight_shape) - set_weight_attrs(w13_weight_shape, extra_weight_attrs) - - w13_g_idx = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_g_idx", w13_g_idx) - set_weight_attrs(w13_g_idx, extra_weight_attrs) - - w2_g_idx = torch.nn.Parameter( - torch.empty( - num_experts, - intermediate_size_per_partition, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_g_idx", w2_g_idx) - set_weight_attrs(w2_g_idx, extra_weight_attrs) - - w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices) - set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) - - w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty( - 
num_experts, - intermediate_size_per_partition, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices) - set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) - - layer.a13_scale = None - layer.a2_scale = None - layer.marlin_state = GPTQMarlinState.REPACK - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - num_experts = layer.w13_weight_g_idx.shape[0] - device = layer.w13_weight_g_idx.device - if self.kernel_backend == "Flashinfer": - dict_weights_mxint4 = prepare_static_weights_for_trtllm_mxint4_moe( - layer.w13_weight_packed, - layer.w13_weight_scale, - layer.w2_weight_packed, - layer.w2_weight_scale, - ) - replace_parameter( - layer, "w13_weight_packed", dict_weights_mxint4["gemm1_weights"] - ) - replace_parameter( - layer, "w13_weight_scale", dict_weights_mxint4["gemm1_scales"] - ) - replace_parameter( - layer, "w2_weight_packed", dict_weights_mxint4["gemm2_weights"] - ) - replace_parameter( - layer, "w2_weight_scale", dict_weights_mxint4["gemm2_scales"] - ) - return None - - is_a_8bit = ( - self.marlin_input_dtype is not None - and self.marlin_input_dtype.itemsize == 1 - ) - - if self.marlin_input_dtype == torch.float8_e4m3fn: - # NOTE: for non-zp quantization format only - ops.marlin_int4_fp8_preprocess(layer.w13_weight_packed, inplace=True) - ops.marlin_int4_fp8_preprocess(layer.w2_weight_packed, inplace=True) - layer.w13_weight_scale.data = layer.w13_weight_scale.data * 512 - layer.w2_weight_scale.data = layer.w2_weight_scale.data * 512 - - # when running models with grouped act order, - # resort to g_idx values provided in checkpoint - if self.actorder == "group": - w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx) - w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx) - w13_sorted_g_idx = torch.empty_like(layer.w13_weight_g_idx) - w2_sorted_g_idx = torch.empty_like(layer.w2_weight_g_idx) - - for e in range(num_experts): - w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_weight_g_idx[e]).to( - torch.int32 - ) - w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_weight_g_idx[e]).to( - torch.int32 - ) - w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][ - w13_g_idx_sort_indices[e] - ] - w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][w2_g_idx_sort_indices[e]] - - replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx) - replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx) - replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices) - replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices) - - else: - layer.w13_weight_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - layer.w2_weight_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - layer.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - layer.w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - - marlin_w13_qweight = ops.gptq_marlin_moe_repack( - layer.w13_weight_packed, - layer.w13_g_idx_sort_indices, - layer.w13_weight_packed.shape[1] * self.packed_factor, - layer.w13_weight_packed.shape[2], - self.num_bits, - is_a_8bit=is_a_8bit, - ) - replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight) - - marlin_w2_qweight = 
ops.gptq_marlin_moe_repack( - layer.w2_weight_packed, - layer.w2_g_idx_sort_indices, - layer.w2_weight_packed.shape[1] * self.packed_factor, - layer.w2_weight_packed.shape[2], - self.num_bits, - is_a_8bit=is_a_8bit, - ) - replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight) - - # Repack scales - marlin_w13_scales = marlin_moe_permute_scales( - s=layer.w13_weight_scale, - size_k=layer.w13_weight_packed.shape[2], - size_n=layer.w13_weight_scale.shape[2], - group_size=self.group_size, - is_a_8bit=is_a_8bit, - ) - if self.marlin_input_dtype == torch.int8 and layer.num_groups_w13 > 1: - marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales( - marlin_w13_scales - ) - layer.register_parameter( - "w13_input_global_scale", - torch.nn.Parameter(w13_input_global_scale, requires_grad=False), - ) - replace_parameter(layer, "w13_weight_scale", marlin_w13_scales) - - marlin_w2_scales = marlin_moe_permute_scales( - s=layer.w2_weight_scale, - size_k=layer.w2_weight_scale.shape[1] - * (self.group_size if self.group_size != -1 else self.packed_factor), - size_n=layer.w2_weight_scale.shape[2], - group_size=self.group_size, - is_a_8bit=is_a_8bit, - ) - if self.marlin_input_dtype == torch.int8 and layer.num_groups_w2 > 1: - marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales( - marlin_w2_scales - ) - layer.register_parameter( - "w2_input_global_scale", - torch.nn.Parameter(w2_input_global_scale, requires_grad=False), - ) - replace_parameter(layer, "w2_weight_scale", marlin_w2_scales) - - layer.workspace = marlin_make_workspace_new(device, 4) - - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: - if self.num_bits != 4: - return None - return int4_w4a16_moe_quant_config( - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - w1_zp=None, - w2_zp=None, - block_shape=[0, self.group_size], - ) - - def select_gemm_impl( - self, - prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular, - layer: torch.nn.Module, - ) -> mk.FusedMoEExpertsModular: - assert self.num_bits == 4, "only supporting w4" - layer.w13_weight = layer.w13_weight_packed - layer.w2_weight = layer.w2_weight_packed - assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]]) - assert self.moe_quant_config is not None - if ( - prepare_finalize.activation_format - == mk.FusedMoEActivationFormat.BatchedExperts - ): - max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() - assert max_num_tokens_per_rank is not None - return BatchedMarlinExperts( - max_num_tokens=max_num_tokens_per_rank, - num_dispatchers=prepare_finalize.num_dispatchers(), - moe_config=self.moe, - quant_config=self.moe_quant_config, - w13_g_idx=layer.w13_weight_g_idx, - w2_g_idx=layer.w2_weight_g_idx, - w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices, - w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices, - is_k_full=self.is_k_full, - ) - else: - return MarlinExperts( - moe_config=self.moe, - quant_config=self.moe_quant_config, - w13_g_idx=layer.w13_weight_g_idx, - w2_g_idx=layer.w2_weight_g_idx, - w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices, - w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices, - is_k_full=self.is_k_full, - ) - - @property - def is_monolithic(self) -> bool: - return self.kernel_backend == "Flashinfer" - - def apply_monolithic( - self, - layer: FusedMoE, - x: torch.Tensor, - router_logits: torch.Tensor, - ) -> torch.Tensor: - assert self.kernel_backend == "Flashinfer" - return flashinfer_trtllm_mxint4_moe( - x=x, - 
router_logits=router_logits, - w13_weight_packed=layer.w13_weight_packed, - w13_weight_scale=layer.w13_weight_scale, - w2_weight_packed=layer.w2_weight_packed, - w2_weight_scale=layer.w2_weight_scale, - global_num_experts=layer.global_num_experts, - top_k=layer.top_k, - intermediate_size_per_partition=layer.intermediate_size_per_partition, - local_num_experts=layer.local_num_experts, - ep_rank=layer.ep_rank, - num_expert_group=layer.num_expert_group, - topk_group=layer.topk_group, - e_score_correction_bias=layer.e_score_correction_bias, - routing_method_type=layer.routing_method_type, - ) - - def apply( - self, - layer: FusedMoE, - x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - shared_experts_input: torch.Tensor | None, - ) -> torch.Tensor: - assert self.kernel_backend == "Marlin" - return fused_marlin_moe( - x, - layer.w13_weight_packed, - layer.w2_weight_packed, - None, - None, - layer.w13_weight_scale, - layer.w2_weight_scale, - topk_weights, - topk_ids, - input_global_scale1=getattr(layer, "w13_input_global_scale", None), - input_global_scale2=getattr(layer, "w2_input_global_scale", None), - quant_type_id=self.quant_type.id, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - global_num_experts=layer.global_num_experts, - activation=layer.activation, - expert_map=layer.expert_map, - g_idx1=layer.w13_weight_g_idx, - g_idx2=layer.w2_weight_g_idx, - sort_indices1=layer.w13_g_idx_sort_indices, - sort_indices2=layer.w2_g_idx_sort_indices, - workspace=layer.workspace, - input_dtype=self.marlin_input_dtype, - is_k_full=self.is_k_full, - inplace=not self.moe.disable_inplace, - ) - - -class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): - def __init__( - self, - weight_quant: QuantizationArgs, - input_quant: QuantizationArgs | None, - moe: FusedMoEConfig, - layer_name: str | None = None, - ): - super().__init__(moe) - self.weight_quant = weight_quant - self.input_quant = input_quant - # Extract properties from weight_quant - self.num_bits = weight_quant.num_bits - self.packed_factor = 32 // weight_quant.num_bits - self.strategy = weight_quant.strategy - # channelwise is not supported by this kernel - assert weight_quant.strategy == "group" - self.group_size = weight_quant.group_size - # grouped actorder isn't supported by this kernel - assert weight_quant.actorder != "group" - assert weight_quant.symmetric, ( - "Only symmetric quantization is supported for MoE" - ) - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - # Will transpose the loaded weight along the - # intermediate and hidden dim sizes. 
Will - # shard for TP along the transposed dims - extra_weight_attrs.update( - {"is_transposed": True, "quant_method": self.strategy} - ) - w13_num_shards = 2 if self.moe.is_act_and_mul else 1 - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size // self.packed_factor, - w13_num_shards * intermediate_size_per_partition, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_packed", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, - intermediate_size_per_partition // self.packed_factor, - hidden_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_packed", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - w2_scales_size = intermediate_size_per_partition - - if self.strategy == "channel": - num_groups_w2 = num_groups_w13 = 1 - self.group_size = -1 - else: - num_groups_w2 = w2_scales_size // self.group_size - num_groups_w13 = hidden_size // self.group_size - - w13_scale = torch.nn.Parameter( - torch.ones( - num_experts, - num_groups_w13, - w13_num_shards * intermediate_size_per_partition, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_scale) - set_weight_attrs(w13_scale, extra_weight_attrs) - - w2_scale = torch.nn.Parameter( - torch.ones(num_experts, num_groups_w2, hidden_size, dtype=params_dtype), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_scale) - set_weight_attrs(w2_scale, extra_weight_attrs) - set_weight_attrs(w2_scale, {"load_full_w2": False}) - - w2_weight_shape = torch.nn.Parameter( - torch.empty(num_experts, 2), requires_grad=False - ) - layer.register_parameter("w2_weight_shape", w2_weight_shape) - set_weight_attrs(w2_weight_shape, extra_weight_attrs) - w13_weight_shape = torch.nn.Parameter( - torch.empty(num_experts, 2), requires_grad=False - ) - - layer.register_parameter("w13_weight_shape", w13_weight_shape) - set_weight_attrs(w13_weight_shape, extra_weight_attrs) - - w13_g_idx = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_g_idx", w13_g_idx) - set_weight_attrs(w13_g_idx, extra_weight_attrs) - - w2_g_idx = torch.nn.Parameter( - torch.empty( - num_experts, - intermediate_size_per_partition, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_g_idx", w2_g_idx) - set_weight_attrs(w2_g_idx, extra_weight_attrs) - - w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices) - set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) - - w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty( - num_experts, - intermediate_size_per_partition, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices) - set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) - - layer.a13_scale = None - layer.a2_scale = None - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # Reconfigure packed weights and scales to match moe_wna16 format - layer.w13_weight_packed = torch.nn.Parameter( - layer.w13_weight_packed.transpose(1, 2).contiguous().view(torch.uint8), - requires_grad=False, - ) - 
layer.w2_weight_packed = torch.nn.Parameter( - layer.w2_weight_packed.transpose(1, 2).contiguous().view(torch.uint8), - requires_grad=False, - ) - layer.w13_weight_scale = torch.nn.Parameter( - layer.w13_weight_scale.transpose(1, 2).contiguous(), requires_grad=False - ) - layer.w2_weight_scale = torch.nn.Parameter( - layer.w2_weight_scale.transpose(1, 2).contiguous(), requires_grad=False - ) - - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: - assert self.num_bits == 4 or self.num_bits == 8 - config_builder = ( - int4_w4a16_moe_quant_config - if self.num_bits == 4 - else int8_w8a16_moe_quant_config - ) - - return config_builder( - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - w1_zp=None, - w2_zp=None, - block_shape=[0, self.group_size], - ) - - def select_gemm_impl( - self, - prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular, - layer: torch.nn.Module, - ) -> mk.FusedMoEExpertsModular: - if self.moe.is_lora_enabled: - assert self.moe_quant_config is not None - from vllm.triton_utils import HAS_TRITON - - if HAS_TRITON: - from vllm.model_executor.layers.fused_moe import TritonWNA16Experts - - layer.w13_weight = layer.w13_weight_packed - layer.w2_weight = layer.w2_weight_packed - return TritonWNA16Experts( - moe_config=self.moe, quant_config=self.moe_quant_config - ) - else: - raise NotImplementedError( - "TritonExperts requires Triton. " - "Install triton or disable LoRA for MoE." - ) - - raise NotImplementedError - - def apply( - self, - layer: FusedMoE, - x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - shared_experts_input: torch.Tensor | None, - ) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe import fused_experts - - return fused_experts( - x, - layer.w13_weight_packed, - layer.w2_weight_packed, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=not self.moe.disable_inplace, - activation=layer.activation, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - quant_config=self.moe_quant_config, - ) - - @property - def supports_eplb(self) -> bool: - return True - - -class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod): - """ - CPU-only MoE method using dynamic 4-bit matmul kernels on Arm Platform - - Weights: int4 (stored as int8 values in [-8,7], packed to uint8 nibbles) - - Scales: Fp32 for Channelwise , bf16 for groupwise quantization - - Bias: Same data type as original weights - - Activations: FP32/Bf16 dynamic per-token (A8 Int), - quantized inside the kernel - """ - - def __init__( - self, - weight_quant: QuantizationArgs, - input_quant: QuantizationArgs, - moe: FusedMoEConfig, - layer_name: str | None = None, - ): - super().__init__(moe) - self.has_bias = self.moe.has_bias - self.weight_quant = weight_quant - self.input_quant = input_quant - - # Validate scheme: weights=W4 (channel or group), - # activations=dynamic TOKEN (A8) - - # Must be dynamic per-token activations - if ( - input_quant.strategy != QuantizationStrategy.TOKEN - or not input_quant.dynamic - ): - raise ValueError( - "W4A8-int MoE needs dynamic per-token activation quantization." 
- ) - - # Weight can be channel-wise (group_size=None) or group-wise - self.group_size = ( - weight_quant.group_size if (weight_quant.group_size is not None) else -1 - ) - if weight_quant.num_bits != 4: - raise ValueError("This method only supports 4-bit weights (num_bits=4).") - - # CPU only - if not current_platform.is_cpu(): - raise ValueError("CompressedTensorsW4A8Int8MoEMethod is CPU-only.") - - # Arm: check _dyn ops availability - if current_platform.get_cpu_architecture() == CpuArchEnum.ARM: - try: - _ = torch.ops.aten._dyn_quant_matmul_4bit - _ = torch.ops.aten._dyn_quant_pack_4bit_weight - except AttributeError as err: - raise RuntimeError( - f"""PyTorch {torch.__version__} lacks _dyn_quant_* 4bit ops; - install a newer build.""" - ) from err - self.static_input_scales = False # always dynamic per token - - # ---- parameter creation ---- - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - # Shapes per local rank (TP/EP): - # w13: [E, 2*I_local, H] int8 (int4 values in [-8,7]) - # w2 : [E, H, I_local] int8 - # Scales: - # channel-wise: group_size=-1 -> per-output-row, single scale per row - # group-wise : group_size=g -> - # per-output-row, (in_features/g) scales - - E = num_experts - H = hidden_size - IN = intermediate_size_per_partition - g = self.group_size - - # Per-row scale columns - def _n_scale_cols(in_features: int) -> int: - return 1 if g == -1 else (in_features // g) - - # Register unpacked int4-as-int8 weights the loader will fill. - w13 = torch.nn.Parameter( - torch.empty(E, 2 * IN, H, dtype=torch.int8), requires_grad=False - ) - set_weight_attrs(w13, extra_weight_attrs) - layer.register_parameter("w13_weight", w13) - - w2 = torch.nn.Parameter( - torch.empty(E, H, IN, dtype=torch.int8), requires_grad=False - ) - set_weight_attrs(w2, extra_weight_attrs) - layer.register_parameter("w2_weight", w2) - - # Register scales - # KleidiAI groupwise kernels accepts float32 scales - # KleidiAI groupwise kernels accepts bfloat16 scales - scale_dtype = torch.float32 if g == -1 else torch.bfloat16 - - w13_s = torch.nn.Parameter( - torch.ones(E, 2 * IN, _n_scale_cols(H), dtype=scale_dtype), - requires_grad=False, - ) - set_weight_attrs( - w13_s, - {"quant_method": "channel" if g == -1 else "group", **extra_weight_attrs}, - ) - layer.register_parameter("w13_weight_scale", w13_s) - - w2_s = torch.nn.Parameter( - torch.ones(E, H, _n_scale_cols(IN), dtype=scale_dtype), requires_grad=False - ) - set_weight_attrs( - w2_s, - {"quant_method": "channel" if g == -1 else "group", **extra_weight_attrs}, - ) - layer.register_parameter("w2_weight_scale", w2_s) - - if self.has_bias: - w13_bias = torch.nn.Parameter( - torch.zeros(E, 2 * IN, dtype=params_dtype), requires_grad=False - ) - layer.register_parameter("w13_bias", w13_bias) - set_weight_attrs(w13_bias, extra_weight_attrs) - - w2_bias = torch.nn.Parameter( - torch.zeros(num_experts, hidden_size, dtype=params_dtype), - requires_grad=False, - ) - layer.register_parameter("w2_bias", w2_bias) - set_weight_attrs(w2_bias, extra_weight_attrs) - - # Placeholders for packed weights (will be replaced after packing) - layer.register_parameter( - "w13_weight_packed", torch.nn.Parameter(torch.empty(0), requires_grad=False) - ) - set_weight_attrs(layer.w13_weight_packed, extra_weight_attrs) - - layer.register_parameter( - "w2_weight_packed", torch.nn.Parameter(torch.empty(0), requires_grad=False) - ) - 
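To make the scale bookkeeping above concrete: channel-wise quantization (group_size of -1) keeps a single scale column per output row, while group-wise keeps `in_features // group_size` columns, with float32 and bfloat16 scales respectively for the KleidiAI kernels. A small sketch with made-up dimensions:

```python
import torch

def n_scale_cols(in_features: int, group_size: int) -> int:
    # group_size == -1 means channel-wise: one scale per output row.
    return 1 if group_size == -1 else in_features // group_size

hidden_size = 128  # illustrative only

# Channel-wise: a single float32 scale column per row.
channel_scales = torch.ones(
    hidden_size, n_scale_cols(hidden_size, -1), dtype=torch.float32
)
# Group-wise with group_size=32: in_features // 32 bfloat16 scale columns.
group_scales = torch.ones(
    hidden_size, n_scale_cols(hidden_size, 32), dtype=torch.bfloat16
)

assert channel_scales.shape == (hidden_size, 1)
assert group_scales.shape == (hidden_size, 4)
```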
set_weight_attrs(layer.w2_weight_packed, extra_weight_attrs) - - # dims for 4 bit fused matmuls - layer.w13_in_features = H - layer.w13_out_features = 2 * IN - layer.w2_in_features = IN - layer.w2_out_features = H - layer.group_size = g - - # post-load packing to dyn-4bit KleidiAI kernel's format - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - E = layer.w13_weight.shape[0] - H = layer.w13_in_features - I2 = layer.w13_out_features - IN = layer.w2_in_features - g = layer.group_size - - def _pack_matrix( - int4_as_int8_2d: torch.Tensor, - scales_2d: torch.Tensor, - bias_1d: torch.Tensor | None, - in_features: int, - out_features: int, - ) -> torch.Tensor: - # int4 values are stored as int8 in [-8,7]. - # Shift to unsigned nibble and pack pairs along input-dim. - tmp = int4_as_int8_2d.add(8) # [out, in] - uint8_nibbles = ((tmp[:, 1::2] << 4) | tmp[:, ::2]).to( - torch.uint8 - ) # [out, in//2] - - # KleidiAI groupwise kernels accepts float32 scales - # KleidiAI groupwise kernels accepts bfloat16 scales - scale_dtype = torch.float32 if g == -1 else torch.bfloat16 - scales = scales_2d.to(scale_dtype) - bias = None if bias_1d is None else bias_1d.to(torch.float32) - return torch.ops.aten._dyn_quant_pack_4bit_weight( - uint8_nibbles, - scales, - bias, - g if g != -1 else in_features, - in_features, - out_features, - ) - - # Pack per expert - w13_packed_list = [] - w2_packed_list = [] - - has_w13_bias = hasattr(layer, "w13_bias") and layer.w13_bias is not None - has_w2_bias = hasattr(layer, "w2_bias") and layer.w2_bias is not None - - for e in range(E): - w13_packed_list.append( - _pack_matrix( - layer.w13_weight[e], # [2I, H] - layer.w13_weight_scale[e], # [2I, H/g or 1] - layer.w13_bias[e] if has_w13_bias else None, # [2I] - H, - I2, - ) - ) - w2_packed_list.append( - _pack_matrix( - # w2 shape is [H, IN]; we need [out, in] == [H, IN]. - layer.w2_weight[e], # [H, IN] - layer.w2_weight_scale[e], # [H, IN/g or 1] - layer.w2_bias[e] if has_w2_bias else None, # [H] - IN, - layer.w2_out_features, # in_features=IN, out_features=H - ) - ) - - # each packed tensor has identical shape per expert; stack on dim 0 - w13_packed = torch.stack(w13_packed_list, dim=0) - w2_packed = torch.stack(w2_packed_list, dim=0) - - replace_parameter( - layer, - "w13_weight_packed", - torch.nn.Parameter(w13_packed, requires_grad=False), - ) - replace_parameter( - layer, - "w2_weight_packed", - torch.nn.Parameter(w2_packed, requires_grad=False), - ) - - # free raw tensors/scales/bias now that they're packed into the payload. - replace_parameter( - layer, "w13_weight", torch.nn.Parameter(torch.empty(0), requires_grad=False) - ) - replace_parameter( - layer, "w2_weight", torch.nn.Parameter(torch.empty(0), requires_grad=False) - ) - replace_parameter( - layer, - "w13_weight_scale", - torch.nn.Parameter(torch.empty(0), requires_grad=False), - ) - replace_parameter( - layer, - "w2_weight_scale", - torch.nn.Parameter(torch.empty(0), requires_grad=False), - ) - if has_w13_bias: - replace_parameter( - layer, - "w13_bias", - torch.nn.Parameter(torch.empty(0), requires_grad=False), - ) - if has_w2_bias: - replace_parameter( - layer, - "w2_bias", - torch.nn.Parameter(torch.empty(0), requires_grad=False), - ) - - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: - # CPU dynamic 4-bit MoE path does not use modular kernels or - # fused_experts; quant config is not needed. 
- return None - - @property - def is_monolithic(self) -> bool: - return True - - def apply_monolithic( - self, - layer: FusedMoE, - x: torch.Tensor, - router_logits: torch.Tensor, - ) -> torch.Tensor: - assert not layer.enable_eplb, "EPLB not supported for W4A8-int MoE yet." - assert layer.activation in ( - MoEActivation.SILU, - MoEActivation.SWIGLUOAI, - MoEActivation.SWIGLUSTEP, - ), "Only SiLU/SwiGLUGU/SwiGLUUG are supported." - assert layer.expert_map is None, """expert_map/EP not implemented - for CPU dyn-4bit MoE.""" - - def _act_kind(s: MoEActivation) -> int: - # 0 = SwiGLU_Gu (SiLU(g)*u), 1 = SwiGLU_Ug (SiLU(u)*g), 2 = SiLU - if s == MoEActivation.SWIGLUSTEP: - return 0 - if s == MoEActivation.SWIGLUOAI: - return 1 - if s == MoEActivation.SILU: - return 2 - raise ValueError(f"Unknown activation '{s}'") - - # Apply topk softmax on router output - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=layer.top_k, - use_grouped_topk=layer.use_grouped_topk, - renormalize=layer.renormalize, - ) - - return torch.ops._C.dynamic_4bit_int_moe( - x, - topk_ids.to(torch.long), - topk_weights, - layer.w13_weight_packed, - layer.w2_weight_packed, - layer.w2_out_features, - layer.w2_in_features, - layer.w13_out_features, - layer.group_size, - layer.apply_router_weight_on_input, - int(_act_kind(layer.activation)), - ) - - -class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): - def __init__( - self, - weight_quant: QuantizationArgs, - input_quant: QuantizationArgs, - moe: FusedMoEConfig, - layer_name: str | None = None, - ): - super().__init__(moe) - self.weight_quant = weight_quant - self.input_quant = input_quant - - self.group_size = self.weight_quant.group_size - self.num_bits = self.weight_quant.num_bits - self.packed_factor = 32 // self.num_bits - - assert self.weight_quant.symmetric, ( - "Only symmetric quantization is supported for W4A8 MoE" - ) - assert self.weight_quant.actorder != "group" - assert self.group_size == 128, "Only group size 128 supported for W4A8 MoE" - - self.disable_expert_map = False - self.layer_name = layer_name - - from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 - from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, - ) - - self.quant_fp8 = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN) - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - layer.num_experts = num_experts - layer.orig_dtype = params_dtype - layer.weight_block_size = None - - # requirement for CUTLASS reorder_tensor - assert hidden_size % 256 == 0, f"{hidden_size=} must be divisible by 256" - assert intermediate_size_per_partition % 256 == 0, ( - f"{intermediate_size_per_partition=} must be divisible by 256" - ) - # storage type, pack 8xint4 into int32 - params_dtype = torch.int32 - - # WEIGHTS - w13_weight_packed = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // self.packed_factor, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_packed", w13_weight_packed) - set_weight_attrs(w13_weight_packed, extra_weight_attrs) - - w2_weight_packed = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition // self.packed_factor, - dtype=params_dtype, - ), - requires_grad=False, - ) - 
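For reference on the `packed_factor` arithmetic above: with 4-bit weights, eight values share one int32 storage element. A minimal sketch of that packing (the real tensors are produced by `cutlass_encode_and_reorder_int4b_grouped`, whose internal bit layout this does not reproduce):

```python
num_bits = 4
packed_factor = 32 // num_bits  # 8 four-bit values per int32 element

# Eight example values in unsigned-nibble form [0, 15].
nibbles = [1, 7, 0, 15, 8, 3, 12, 5]

# Pack little-endian: value i occupies bits [4*i, 4*i + 4).
word = 0
for i, v in enumerate(nibbles):
    word |= (v & 0xF) << (4 * i)

# Unpack and verify the round trip.
assert [(word >> (4 * i)) & 0xF for i in range(packed_factor)] == nibbles
```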
layer.register_parameter("w2_weight_packed", w2_weight_packed) - set_weight_attrs(w2_weight_packed, extra_weight_attrs) - - # SCALES - # weight_scale refers to the group-wise scales - # they are initially loaded as bf16, we will convert to fp8 - # after loading - w13_weight_scale = torch.nn.Parameter( - torch.ones( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // self.group_size, - dtype=layer.orig_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - - w2_weight_scale = torch.nn.Parameter( - torch.ones( - num_experts, - hidden_size, - intermediate_size_per_partition // self.group_size, - dtype=layer.orig_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - # Add PER-GROUP quantization for FusedMoE.weight_loader. - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} - ) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - - # weight shapes - w2_weight_shape = torch.nn.Parameter( - torch.empty(num_experts, 2), requires_grad=False - ) - layer.register_parameter("w2_weight_shape", w2_weight_shape) - set_weight_attrs(w2_weight_shape, extra_weight_attrs) - w13_weight_shape = torch.nn.Parameter( - torch.empty(num_experts, 2), requires_grad=False - ) - layer.register_parameter("w13_weight_shape", w13_weight_shape) - set_weight_attrs(w13_weight_shape, extra_weight_attrs) - - # don't use input scales - layer.w13_input_scale = None - layer.w2_input_scale = None - - def process_weights_after_loading(self, layer): - device = layer.w13_weight_packed.device - - # STRIDES - # A, C - self.a_strides1_c_strides2 = torch.full( - (layer.local_num_experts,), - layer.hidden_size, - device=device, - dtype=torch.int64, - ) - self.a_strides2 = torch.full( - (layer.local_num_experts,), - layer.intermediate_size_per_partition, - device=device, - dtype=torch.int64, - ) - self.c_strides1 = torch.full( - (layer.local_num_experts,), - 2 * layer.intermediate_size_per_partition, - device=device, - dtype=torch.int64, - ) - - # S (group-wise scales) - # sizeof(StrideS) = 16 bytes, so we need to use 2xint64 to encode it - self.s_strides1 = torch.zeros( - (layer.local_num_experts, 2), device=device, dtype=torch.int64 - ) - self.s_strides1[:, 0] = 2 * layer.intermediate_size_per_partition - - self.s_strides2 = torch.zeros( - (layer.local_num_experts, 2), device=device, dtype=torch.int64 - ) - self.s_strides2[:, 0] = layer.hidden_size - - # encode and reorder weight tensors, and get the layout to pass to - # the grouped gemm kernel. 
`b_strides1/2` specifies the entire layout - convert_packed_uint4b8_to_signed_int4_inplace(layer.w13_weight_packed) - w13_weight_shuffled, self.b_strides1 = ( - ops.cutlass_encode_and_reorder_int4b_grouped(layer.w13_weight_packed) - ) - replace_parameter(layer, "w13_weight_packed", w13_weight_shuffled) - convert_packed_uint4b8_to_signed_int4_inplace(layer.w2_weight_packed) - w2_weight_shuffled, self.b_strides2 = ( - ops.cutlass_encode_and_reorder_int4b_grouped(layer.w2_weight_packed) - ) - replace_parameter(layer, "w2_weight_packed", w2_weight_shuffled) - - # convert bf16 scales to (fp8_scales, channel_scales) - w13_weight_scale, w13_weight_chan_scale = convert_bf16_scales_to_fp8( - self.quant_fp8, layer.w13_weight_scale - ) - w2_weight_scale, w2_weight_chan_scale = convert_bf16_scales_to_fp8( - self.quant_fp8, layer.w2_weight_scale - ) - - # register channel scales - layer.register_parameter( - "w13_weight_chan_scale", - torch.nn.Parameter(w13_weight_chan_scale, requires_grad=False), - ) - layer.register_parameter( - "w2_weight_chan_scale", - torch.nn.Parameter(w2_weight_chan_scale, requires_grad=False), - ) - - # The scales are stored as (E, N, K // 128) but the kernel expects - # (E, K // 128, N) in row-major format, so we need to permute the last 2 dims - # and make it contiguous - w13_weight_scale_packed = ops.cutlass_pack_scale_fp8( - w13_weight_scale.permute(0, 2, 1).contiguous() - ) - replace_parameter(layer, "w13_weight_scale", w13_weight_scale_packed) - w2_weight_scale_packed = ops.cutlass_pack_scale_fp8( - w2_weight_scale.permute(0, 2, 1).contiguous() - ) - replace_parameter(layer, "w2_weight_scale", w2_weight_scale_packed) - - def maybe_make_prepare_finalize( - self, - routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - ) -> mk.FusedMoEPrepareAndFinalizeModular | None: - return super().maybe_make_prepare_finalize(routing_tables) - - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: - # Store quantization scales; both per-group and per-channel - # Note we haven't specified the group size here because - # the quant config logic assumes group-wise scaling - # and channel-wise scaling are exclusive. 
- return int4_w4afp8_moe_quant_config( - w1_scale=layer.w13_weight_scale, # group scale - w2_scale=layer.w2_weight_scale, # group scale - g1_alphas=layer.w13_weight_chan_scale, - g2_alphas=layer.w2_weight_chan_scale, - per_act_token_quant=True, # always use dynamic per-token - per_out_ch_quant=True, # always use per-channel - ) - - def select_gemm_impl( - self, - prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular, - layer: torch.nn.Module, - ) -> mk.FusedMoEExpertsModular: - assert self.moe_quant_config is not None - assert ( - prepare_finalize.activation_format == FusedMoEActivationFormat.Standard - ), "BatchedExperts not supported" - - from vllm.model_executor.layers.fused_moe import CutlassExpertsW4A8Fp8 - - experts: FusedMoEExpertsModular - - logger.debug("CutlassExpertsW4A8Fp8(%s)", self.__class__.__name__) - experts = CutlassExpertsW4A8Fp8( - out_dtype=self.moe.in_dtype, - a_strides1=self.a_strides1_c_strides2, - a_strides2=self.a_strides2, - b_strides1=self.b_strides1, - b_strides2=self.b_strides2, - c_strides1=self.c_strides1, - c_strides2=self.a_strides1_c_strides2, - s_strides1=self.s_strides1, - s_strides2=self.s_strides2, - moe_config=self.moe, - quant_config=self.moe_quant_config, - group_size=self.group_size, - ) - - num_dispatchers = prepare_finalize.num_dispatchers() - self.disable_expert_map = ( - num_dispatchers > 1 or not experts.supports_expert_map() - ) - - return experts - - def apply( - self, - layer: FusedMoE, - x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - shared_experts_input: torch.Tensor | None, - ) -> torch.Tensor: - if layer.enable_eplb: - raise NotImplementedError( - "EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet." - ) - assert self.moe_quant_config is not None - - from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_w4a8_fp8, - ) - - return cutlass_moe_w4a8_fp8( - x, - layer.w13_weight_packed, - layer.w2_weight_packed, - topk_weights, - topk_ids, - moe_config=self.moe, - quant_config=self.moe_quant_config, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - expert_map=None if self.disable_expert_map else layer.expert_map, - a_strides1=self.a_strides1_c_strides2, - a_strides2=self.a_strides2, - b_strides1=self.b_strides1, - b_strides2=self.b_strides2, - c_strides1=self.c_strides1, - c_strides2=self.a_strides1_c_strides2, - s_strides1=self.s_strides1, - s_strides2=self.s_strides2, - group_size=self.group_size, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - ) - - @property - def supports_eplb(self) -> bool: - return False diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/__init__.py new file mode 100644 index 000000000..39c9113f6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/__init__.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.compressed_tensors_moe import ( # noqa: E501 + CompressedTensorsMoEMethod, +) + +__all__ = [ + "CompressedTensorsMoEMethod", +] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe.py 
b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe.py new file mode 100644 index 000000000..9ee8df9da --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe.py @@ -0,0 +1,175 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import torch +from compressed_tensors import CompressionFormat +from compressed_tensors.quantization import ( + ActivationOrdering, + QuantizationStrategy, +) + +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import ( + FusedMoEMethodBase, + UnquantizedFusedMoEMethod, +) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa + WNA16_SUPPORTED_BITS, +) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_moe_marlin_supports_layer, +) +from vllm.platforms import current_platform + +logger = init_logger(__name__) + + +class CompressedTensorsMoEMethod(FusedMoEMethodBase): + @staticmethod + def get_moe_method( + quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + layer: torch.nn.Module, + layer_name: str, + ) -> FusedMoEMethodBase: + # FusedMoE was made by combining multiple Linears so need to + # make sure quantization config for Linear can target it + quant_config._add_fused_moe_to_target_scheme_map() + unfused_names = [ + layer_name + proj_name + for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"] + ] + # TODO: refactor this to use expert_mapping and check all layer numbers + all_scheme_dicts = [ + quant_config.get_scheme_dict(layer, name) for name in unfused_names + ] + scheme_dict = all_scheme_dicts.pop() + + # multiple schemes found + if not all([cur_dict == scheme_dict for cur_dict in all_scheme_dicts]): + raise ValueError( + "All MoE projections need to have same " + "quantization scheme but found multiple" + ) + + if scheme_dict is None: # ignored layer + return UnquantizedFusedMoEMethod(layer.moe_config) + + # TODO: @dsikka: refactor this to use schemes as other kernels + # are supported + check if the layer is being ignored. + weight_quant = scheme_dict.get("weights") + input_quant = scheme_dict.get("input_activations") + format = scheme_dict.get("format") + + if quant_config._is_mxfp4(weight_quant): + from .compressed_tensors_moe_w4a4_mxfp4 import ( + CompressedTensorsW4A4Mxfp4MoEMethod, + ) + + return CompressedTensorsW4A4Mxfp4MoEMethod(layer.moe_config) + + if quant_config._is_wNa16_group_channel(weight_quant, input_quant): + # group_size=None means channelwise + group_size = weight_quant.group_size or -1 + + valid_format_and_bits = ( + weight_quant.num_bits in WNA16_SUPPORTED_BITS + and format == CompressionFormat.pack_quantized.value + ) + + if not valid_format_and_bits: + raise ValueError( + "For Fused MoE layers, only format: ", + f"{CompressionFormat.pack_quantized.value} ", + f" and bits: {WNA16_SUPPORTED_BITS} is supported ", + f"but got format: {CompressionFormat.pack_quantized.value} " + f" and bits: {weight_quant.num_bits}", + ) + + # Prefer to use the MarlinMoE kernel when it is supported. 
+ if ( + not check_moe_marlin_supports_layer(layer, group_size) + or current_platform.is_rocm() + ): + from .compressed_tensors_moe_wna16 import ( + CompressedTensorsWNA16MoEMethod, + ) + + if ( + weight_quant.strategy == QuantizationStrategy.GROUP + and weight_quant.actorder + in (ActivationOrdering.GROUP, ActivationOrdering.DYNAMIC) + ): + raise ValueError( + "WNA16MoE is not supported with actorder=group/dynamic." + ) + logger.info_once("Using CompressedTensorsWNA16MoEMethod") + return CompressedTensorsWNA16MoEMethod( + weight_quant, input_quant, layer.moe_config + ) + else: + from .compressed_tensors_moe_wna16_marlin import ( + CompressedTensorsWNA16MarlinMoEMethod, + ) + + logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") + return CompressedTensorsWNA16MarlinMoEMethod( + weight_quant, input_quant, layer.moe_config + ) + elif quant_config._is_nvfp4_format(weight_quant): + from .compressed_tensors_moe_w4a4_nvfp4 import ( + CompressedTensorsW4A4Nvfp4MoEMethod, + ) + + _is_valid_nvfp4_activations = ( + quant_config._is_nvfp4_format(input_quant) or input_quant is None + ) + if not _is_valid_nvfp4_activations: + raise ValueError( + "For NVFP4 weights, input quantization must also be NVFP4 format ", + f"or None for NVFP4A16, found {input_quant}", + ) + return CompressedTensorsW4A4Nvfp4MoEMethod( + layer.moe_config, layer_name, use_a16=(input_quant is None) + ) + elif ( + quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) + or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant) + or quant_config._is_fp8_w8a8(weight_quant, input_quant) + ): + from .compressed_tensors_moe_w8a8_fp8 import ( + CompressedTensorsW8A8Fp8MoEMethod, + ) + + return CompressedTensorsW8A8Fp8MoEMethod( + weight_quant, input_quant, layer.moe_config + ) + elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): + from .compressed_tensors_moe_w8a8_int8 import ( + CompressedTensorsW8A8Int8MoEMethod, + ) + + return CompressedTensorsW8A8Int8MoEMethod( + weight_quant, input_quant, layer.moe_config + ) + elif quant_config._is_fp8_w4a8_sm90(weight_quant, input_quant): + from .compressed_tensors_moe_w4a8_fp8 import ( + CompressedTensorsW4A8Fp8MoEMethod, + ) + + logger.info_once("Using CompressedTensorsW4A8Fp8MoEMethod") + return CompressedTensorsW4A8Fp8MoEMethod( + weight_quant, input_quant, layer.moe_config + ) + elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant): + from .compressed_tensors_moe_w4a8_int8 import ( + CompressedTensorsW4A8Int8MoEMethod, + ) + + return CompressedTensorsW4A8Int8MoEMethod( + weight_quant, input_quant, layer.moe_config + ) + else: + raise RuntimeError( + f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}" + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py new file mode 100644 index 000000000..8cc6d17bc --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py @@ -0,0 +1,168 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, + FusedMoeWeightScaleSupported, +) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, +) +from 
vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + MarlinExperts, +) +from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( + Mxfp4MoeBackend, + make_mxfp4_moe_kernel, + make_mxfp4_moe_quant_config, +) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa E501 + CompressedTensorsMoEMethod, +) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + prepare_moe_fp4_layer_for_marlin, +) +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) + + +class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): + def __init__(self, moe): + super().__init__(moe) + self.group_size = 32 + self.mxfp4_backend = Mxfp4MoeBackend.MARLIN + self.experts_cls = MarlinExperts + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + layer.num_experts = num_experts + layer.params_dtype = params_dtype + + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // 2, + requires_grad=False, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // 2, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w13_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // self.group_size, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // self.group_size, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + return make_mxfp4_moe_quant_config( + mxfp4_backend=self.mxfp4_backend, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + ) + + def process_weights_after_loading(self, layer: FusedMoE) -> None: + layer.w13_weight = torch.nn.Parameter( + layer.w13_weight_packed.data, requires_grad=False + ) + delattr(layer, "w13_weight_packed") + + layer.w2_weight = torch.nn.Parameter( + layer.w2_weight_packed.data, requires_grad=False + ) + delattr(layer, "w2_weight_packed") + + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression " + "will be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads." 
+ ) + prepare_moe_fp4_layer_for_marlin(layer) + + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + if self.moe_quant_config is not None: + self.moe_kernel = make_mxfp4_moe_kernel( + moe_quant_config=self.moe_quant_config, + moe_config=self.moe, + experts_cls=self.experts_cls, + mxfp4_backend=self.mxfp4_backend, + shared_experts=layer.shared_experts, + routing_tables=layer._maybe_init_expert_routing_tables(), + ) + + def apply( + self, + layer: FusedMoE, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor: + assert self.moe_kernel is not None + return self.moe_kernel.apply( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + shared_experts_input=shared_experts_input, + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py new file mode 100644 index 000000000..09a216fd2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py @@ -0,0 +1,306 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, + FusedMoeWeightScaleSupported, +) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) +from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( + convert_to_nvfp4_moe_kernel_format, + is_global_sf_supported_for_nvfp4_backend, + make_nvfp4_moe_kernel, + make_nvfp4_moe_quant_config, + select_nvfp4_moe_backend, +) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa E501 + CompressedTensorsMoEMethod, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kNvfp4Dynamic, + kNvfp4Static, +) +from vllm.model_executor.utils import replace_parameter, set_weight_attrs + +logger = init_logger(__name__) + + +class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): + def __init__( + self, + moe: FusedMoEConfig, + layer_name: str | None = None, + use_a16: bool = False, + ): + super().__init__(moe) + self.group_size = 16 + + # Select experts implementation. 
+ self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend( + config=self.moe, + weight_key=kNvfp4Static, + activation_key=None if use_a16 else kNvfp4Dynamic, + ) + + self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( + self.nvfp4_backend + ) + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + layer.num_experts = num_experts + layer.params_dtype = params_dtype + w13_num_shards = 2 if self.moe.is_act_and_mul else 1 + + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + w13_num_shards * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // 2, + requires_grad=False, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // 2, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # Weight Scales + w13_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + w13_num_shards * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // self.group_size, + dtype=torch.float8_e4m3fn, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // self.group_size, + dtype=torch.float8_e4m3fn, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} + ) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # Weight Global Scales + w13_weight_scale_2 = torch.nn.Parameter( + torch.empty(num_experts, w13_num_shards, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + set_weight_attrs(w13_weight_scale_2, extra_weight_attrs) + + w2_weight_scale_2 = torch.nn.Parameter( + torch.empty(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + set_weight_attrs(w2_weight_scale_2, extra_weight_attrs) + + # Input Global Scales + w13_input_scale = torch.nn.Parameter( + torch.empty(num_experts, w13_num_shards, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_input_global_scale", w13_input_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + torch.empty(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_input_global_scale", w2_input_scale) + 
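A summary of the NVFP4 parameter shapes registered above, with illustrative sizes (assuming two FP4 values per byte and the group size of 16 used by this method):

```python
import torch

# Illustrative sizes only.
num_experts, hidden_size, intermediate = 2, 256, 128
group_size = 16  # NVFP4 block size along the input dimension

# Packed FP4 weights: two 4-bit values per uint8 byte along the input dim.
w2_weight = torch.empty(
    num_experts, hidden_size, intermediate // 2, dtype=torch.uint8
)

# Per-block scales: one float8_e4m3fn scale per 16 input elements.
w2_weight_scale = torch.empty(
    num_experts, hidden_size, intermediate // group_size, dtype=torch.float8_e4m3fn
)

# Per-expert float32 global scales for weights and activations.
w2_weight_global_scale = torch.empty(num_experts, dtype=torch.float32)
w2_input_global_scale = torch.empty(num_experts, dtype=torch.float32)
```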
extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + def process_weights_after_loading(self, layer: FusedMoE) -> None: + """ + Convert NVFP4 MoE weights into kernel format and setup the kernel. + """ + # NOTE(rob): wN_weight_packed -> wN_weight is because ModularKernelMethod + # requires this naming convention. However, the name change breaks + # reloading because the state dict no longer matches disk. Once we + # remove MKM, we should revert this change to ensure compatibility. + layer.w13_weight = torch.nn.Parameter( + layer.w13_weight_packed.data, requires_grad=False + ) + delattr(layer, "w13_weight_packed") + + layer.w2_weight = torch.nn.Parameter( + layer.w2_weight_packed.data, requires_grad=False + ) + delattr(layer, "w2_weight_packed") + + # Use a single gscale for w13. + if self.moe.is_act_and_mul and not torch.allclose( + layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1] + ): + logger.warning_once( + "w1_weight_global_scale must match w3_weight_global_scale. " + "Accuracy may be affected.", + ) + w13_weight_global_scale = layer.w13_weight_global_scale[:, 0].contiguous() + + # Shuffle weights into the NvFp4 kernel format. + ( + w13, + w13_scale, + w13_scale_2, + a13_scale, + w2, + w2_scale, + w2_scale_2, + a2_scale, + ) = convert_to_nvfp4_moe_kernel_format( + nvfp4_backend=self.nvfp4_backend, + layer=layer, + w13=layer.w13_weight, + w13_scale=layer.w13_weight_scale, + w13_scale_2=(1.0 / w13_weight_global_scale), + a13_scale=(1.0 / layer.w13_input_global_scale), + w2=layer.w2_weight, + w2_scale=layer.w2_weight_scale, + w2_scale_2=(1.0 / layer.w2_weight_global_scale), + a2_scale=(1.0 / layer.w2_input_global_scale), + is_act_and_mul=self.moe.is_act_and_mul, + ) + + replace_parameter(layer, "w13_weight", w13) + replace_parameter(layer, "w13_weight_scale", w13_scale) + replace_parameter(layer, "w2_weight", w2) + replace_parameter(layer, "w2_weight_scale", w2_scale) + layer.w13_weight_scale_2 = w13_scale_2 + layer.w2_weight_scale_2 = w2_scale_2 + layer.w13_input_scale = a13_scale + layer.w2_input_scale = a2_scale + + # Setup modular kernel. + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + assert self.experts_cls is not None + self.moe_kernel = make_nvfp4_moe_kernel( + moe_quant_config=self.moe_quant_config, + moe_config=self.moe, + experts_cls=self.experts_cls, + shared_experts=layer.shared_experts, + routing_tables=layer._maybe_init_expert_routing_tables(), + ) + self.moe_kernel.fused_experts.process_weights_after_loading(layer) + + def maybe_make_prepare_finalize( + self, + routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + ) -> mk.FusedMoEPrepareAndFinalizeModular | None: + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." 
+ ) + + def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: + return make_nvfp4_moe_quant_config( + backend=self.nvfp4_backend, + w13_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + w13_scale_2=layer.w13_weight_scale_2, + w2_scale_2=layer.w2_weight_scale_2, + a13_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + ) + + def apply_monolithic( + self, + layer: FusedMoE, + x: torch.Tensor, + router_logits: torch.Tensor, + ) -> torch.Tensor: + assert self.is_monolithic + assert self.moe_kernel is not None + return self.moe_kernel.apply_monolithic( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + e_score_correction_bias=layer.e_score_correction_bias, + routed_scaling_factor=layer.routed_scaling_factor, + ) + + def apply( + self, + layer: FusedMoE, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor: + assert self.moe_kernel is not None + return self.moe_kernel.apply( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + shared_experts_input=shared_experts_input, + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py new file mode 100644 index 000000000..74cb0b4f6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_fp8.py @@ -0,0 +1,343 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import torch +from compressed_tensors.quantization import ( + QuantizationArgs, +) + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, + FusedMoEActivationFormat, + FusedMoEExpertsModular, + FusedMoeWeightScaleSupported, +) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, + int4_w4afp8_moe_quant_config, +) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa E501 + CompressedTensorsMoEMethod, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + convert_bf16_scales_to_fp8, + convert_packed_uint4b8_to_signed_int4_inplace, +) +from vllm.model_executor.utils import replace_parameter, set_weight_attrs + +logger = init_logger(__name__) + + +class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): + def __init__( + self, + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs, + moe: FusedMoEConfig, + layer_name: str | None = None, + ): + super().__init__(moe) + self.weight_quant = weight_quant + self.input_quant = input_quant + + self.group_size = self.weight_quant.group_size + self.num_bits = self.weight_quant.num_bits + self.packed_factor = 32 // self.num_bits + + assert 
self.weight_quant.symmetric, ( + "Only symmetric quantization is supported for W4A8 MoE" + ) + assert self.weight_quant.actorder != "group" + assert self.group_size == 128, "Only group size 128 supported for W4A8 MoE" + + self.disable_expert_map = False + self.layer_name = layer_name + + from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 + from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + ) + + self.quant_fp8 = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN) + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + layer.num_experts = num_experts + layer.orig_dtype = params_dtype + layer.weight_block_size = None + + # requirement for CUTLASS reorder_tensor + assert hidden_size % 256 == 0, f"{hidden_size=} must be divisible by 256" + assert intermediate_size_per_partition % 256 == 0, ( + f"{intermediate_size_per_partition=} must be divisible by 256" + ) + # storage type, pack 8xint4 into int32 + params_dtype = torch.int32 + + # WEIGHTS + w13_weight_packed = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // self.packed_factor, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_packed", w13_weight_packed) + set_weight_attrs(w13_weight_packed, extra_weight_attrs) + + w2_weight_packed = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // self.packed_factor, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_packed", w2_weight_packed) + set_weight_attrs(w2_weight_packed, extra_weight_attrs) + + # SCALES + # weight_scale refers to the group-wise scales + # they are initially loaded as bf16, we will convert to fp8 + # after loading + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // self.group_size, + dtype=layer.orig_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + hidden_size, + intermediate_size_per_partition // self.group_size, + dtype=layer.orig_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-GROUP quantization for FusedMoE.weight_loader. 
+ extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # weight shapes + w2_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) + layer.register_parameter("w2_weight_shape", w2_weight_shape) + set_weight_attrs(w2_weight_shape, extra_weight_attrs) + w13_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) + layer.register_parameter("w13_weight_shape", w13_weight_shape) + set_weight_attrs(w13_weight_shape, extra_weight_attrs) + + # don't use input scales + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer): + device = layer.w13_weight_packed.device + + # STRIDES + # A, C + self.a_strides1_c_strides2 = torch.full( + (layer.local_num_experts,), + layer.hidden_size, + device=device, + dtype=torch.int64, + ) + self.a_strides2 = torch.full( + (layer.local_num_experts,), + layer.intermediate_size_per_partition, + device=device, + dtype=torch.int64, + ) + self.c_strides1 = torch.full( + (layer.local_num_experts,), + 2 * layer.intermediate_size_per_partition, + device=device, + dtype=torch.int64, + ) + + # S (group-wise scales) + # sizeof(StrideS) = 16 bytes, so we need to use 2xint64 to encode it + self.s_strides1 = torch.zeros( + (layer.local_num_experts, 2), device=device, dtype=torch.int64 + ) + self.s_strides1[:, 0] = 2 * layer.intermediate_size_per_partition + + self.s_strides2 = torch.zeros( + (layer.local_num_experts, 2), device=device, dtype=torch.int64 + ) + self.s_strides2[:, 0] = layer.hidden_size + + # encode and reorder weight tensors, and get the layout to pass to + # the grouped gemm kernel. 
`b_strides1/2` specifies the entire layout + convert_packed_uint4b8_to_signed_int4_inplace(layer.w13_weight_packed) + w13_weight_shuffled, self.b_strides1 = ( + ops.cutlass_encode_and_reorder_int4b_grouped(layer.w13_weight_packed) + ) + replace_parameter(layer, "w13_weight_packed", w13_weight_shuffled) + convert_packed_uint4b8_to_signed_int4_inplace(layer.w2_weight_packed) + w2_weight_shuffled, self.b_strides2 = ( + ops.cutlass_encode_and_reorder_int4b_grouped(layer.w2_weight_packed) + ) + replace_parameter(layer, "w2_weight_packed", w2_weight_shuffled) + + # convert bf16 scales to (fp8_scales, channel_scales) + w13_weight_scale, w13_weight_chan_scale = convert_bf16_scales_to_fp8( + self.quant_fp8, layer.w13_weight_scale + ) + w2_weight_scale, w2_weight_chan_scale = convert_bf16_scales_to_fp8( + self.quant_fp8, layer.w2_weight_scale + ) + + # register channel scales + layer.register_parameter( + "w13_weight_chan_scale", + torch.nn.Parameter(w13_weight_chan_scale, requires_grad=False), + ) + layer.register_parameter( + "w2_weight_chan_scale", + torch.nn.Parameter(w2_weight_chan_scale, requires_grad=False), + ) + + # The scales are stored as (E, N, K // 128) but the kernel expects + # (E, K // 128, N) in row-major format, so we need to permute the last 2 dims + # and make it contiguous + w13_weight_scale_packed = ops.cutlass_pack_scale_fp8( + w13_weight_scale.permute(0, 2, 1).contiguous() + ) + replace_parameter(layer, "w13_weight_scale", w13_weight_scale_packed) + w2_weight_scale_packed = ops.cutlass_pack_scale_fp8( + w2_weight_scale.permute(0, 2, 1).contiguous() + ) + replace_parameter(layer, "w2_weight_scale", w2_weight_scale_packed) + + def maybe_make_prepare_finalize( + self, + routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + ) -> mk.FusedMoEPrepareAndFinalizeModular | None: + return super().maybe_make_prepare_finalize(routing_tables) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + # Store quantization scales; both per-group and per-channel + # Note we haven't specified the group size here because + # the quant config logic assumes group-wise scaling + # and channel-wise scaling are exclusive. 
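Relatedly, the scale re-layout performed in `process_weights_after_loading` above (the loader stores group scales as (E, N, K // 128) while the kernel wants (E, K // 128, N) in row-major order) amounts to a permute of the last two dims, sketched here before the additional packing done by `cutlass_pack_scale_fp8`:

```python
import torch

E, N, K, G = 2, 512, 256, 128  # illustrative sizes

loader_scales = torch.ones(E, N, K // G, dtype=torch.bfloat16)  # (E, N, K//G)
kernel_scales = loader_scales.permute(0, 2, 1).contiguous()     # (E, K//G, N), row-major

assert kernel_scales.shape == (E, K // G, N)
assert kernel_scales.is_contiguous()
```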
+ return int4_w4afp8_moe_quant_config( + w1_scale=layer.w13_weight_scale, # group scale + w2_scale=layer.w2_weight_scale, # group scale + g1_alphas=layer.w13_weight_chan_scale, + g2_alphas=layer.w2_weight_chan_scale, + per_act_token_quant=True, # always use dynamic per-token + per_out_ch_quant=True, # always use per-channel + ) + + def select_gemm_impl( + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular, + layer: torch.nn.Module, + ) -> mk.FusedMoEExpertsModular: + assert self.moe_quant_config is not None + assert ( + prepare_finalize.activation_format == FusedMoEActivationFormat.Standard + ), "BatchedExperts not supported" + + from vllm.model_executor.layers.fused_moe import CutlassExpertsW4A8Fp8 + + experts: FusedMoEExpertsModular + + logger.debug("CutlassExpertsW4A8Fp8(%s)", self.__class__.__name__) + experts = CutlassExpertsW4A8Fp8( + out_dtype=self.moe.in_dtype, + a_strides1=self.a_strides1_c_strides2, + a_strides2=self.a_strides2, + b_strides1=self.b_strides1, + b_strides2=self.b_strides2, + c_strides1=self.c_strides1, + c_strides2=self.a_strides1_c_strides2, + s_strides1=self.s_strides1, + s_strides2=self.s_strides2, + moe_config=self.moe, + quant_config=self.moe_quant_config, + group_size=self.group_size, + ) + + num_dispatchers = prepare_finalize.num_dispatchers() + self.disable_expert_map = ( + num_dispatchers > 1 or not experts.supports_expert_map() + ) + + return experts + + def apply( + self, + layer: FusedMoE, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor: + if layer.enable_eplb: + raise NotImplementedError( + "EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet." + ) + assert self.moe_quant_config is not None + + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_w4a8_fp8, + ) + + return cutlass_moe_w4a8_fp8( + x, + layer.w13_weight_packed, + layer.w2_weight_packed, + topk_weights, + topk_ids, + moe_config=self.moe, + quant_config=self.moe_quant_config, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=None if self.disable_expert_map else layer.expert_map, + a_strides1=self.a_strides1_c_strides2, + a_strides2=self.a_strides2, + b_strides1=self.b_strides1, + b_strides2=self.b_strides2, + c_strides1=self.c_strides1, + c_strides2=self.a_strides1_c_strides2, + s_strides1=self.s_strides1, + s_strides2=self.s_strides2, + group_size=self.group_size, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + ) + + @property + def supports_eplb(self) -> bool: + return False diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_int8.py new file mode 100644 index 000000000..2e8a935ca --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_int8.py @@ -0,0 +1,349 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import torch +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationStrategy, +) + +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, +) +from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + 
FusedMoEQuantConfig, +) +from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa E501 + CompressedTensorsMoEMethod, +) +from vllm.model_executor.utils import replace_parameter, set_weight_attrs +from vllm.platforms import CpuArchEnum, current_platform + +logger = init_logger(__name__) + + +class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod): + """ + CPU-only MoE method using dynamic 4-bit matmul kernels on Arm Platform + - Weights: int4 (stored as int8 values in [-8,7], packed to uint8 nibbles) + - Scales: Fp32 for Channelwise , bf16 for groupwise quantization + - Bias: Same data type as original weights + - Activations: FP32/Bf16 dynamic per-token (A8 Int), + quantized inside the kernel + """ + + def __init__( + self, + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs, + moe: FusedMoEConfig, + layer_name: str | None = None, + ): + super().__init__(moe) + self.has_bias = self.moe.has_bias + self.weight_quant = weight_quant + self.input_quant = input_quant + + # Validate scheme: weights=W4 (channel or group), + # activations=dynamic TOKEN (A8) + + # Must be dynamic per-token activations + if ( + input_quant.strategy != QuantizationStrategy.TOKEN + or not input_quant.dynamic + ): + raise ValueError( + "W4A8-int MoE needs dynamic per-token activation quantization." + ) + + # Weight can be channel-wise (group_size=None) or group-wise + self.group_size = ( + weight_quant.group_size if (weight_quant.group_size is not None) else -1 + ) + if weight_quant.num_bits != 4: + raise ValueError("This method only supports 4-bit weights (num_bits=4).") + + # CPU only + if not current_platform.is_cpu(): + raise ValueError("CompressedTensorsW4A8Int8MoEMethod is CPU-only.") + + # Arm: check _dyn ops availability + if current_platform.get_cpu_architecture() == CpuArchEnum.ARM: + try: + _ = torch.ops.aten._dyn_quant_matmul_4bit + _ = torch.ops.aten._dyn_quant_pack_4bit_weight + except AttributeError as err: + raise RuntimeError( + f"""PyTorch {torch.__version__} lacks _dyn_quant_* 4bit ops; + install a newer build.""" + ) from err + self.static_input_scales = False # always dynamic per token + + # ---- parameter creation ---- + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Shapes per local rank (TP/EP): + # w13: [E, 2*I_local, H] int8 (int4 values in [-8,7]) + # w2 : [E, H, I_local] int8 + # Scales: + # channel-wise: group_size=-1 -> per-output-row, single scale per row + # group-wise : group_size=g -> + # per-output-row, (in_features/g) scales + + E = num_experts + H = hidden_size + IN = intermediate_size_per_partition + g = self.group_size + + # Per-row scale columns + def _n_scale_cols(in_features: int) -> int: + return 1 if g == -1 else (in_features // g) + + # Register unpacked int4-as-int8 weights the loader will fill. 
+ w13 = torch.nn.Parameter( + torch.empty(E, 2 * IN, H, dtype=torch.int8), requires_grad=False + ) + set_weight_attrs(w13, extra_weight_attrs) + layer.register_parameter("w13_weight", w13) + + w2 = torch.nn.Parameter( + torch.empty(E, H, IN, dtype=torch.int8), requires_grad=False + ) + set_weight_attrs(w2, extra_weight_attrs) + layer.register_parameter("w2_weight", w2) + + # Register scales + # KleidiAI groupwise kernels accepts float32 scales + # KleidiAI groupwise kernels accepts bfloat16 scales + scale_dtype = torch.float32 if g == -1 else torch.bfloat16 + + w13_s = torch.nn.Parameter( + torch.ones(E, 2 * IN, _n_scale_cols(H), dtype=scale_dtype), + requires_grad=False, + ) + set_weight_attrs( + w13_s, + {"quant_method": "channel" if g == -1 else "group", **extra_weight_attrs}, + ) + layer.register_parameter("w13_weight_scale", w13_s) + + w2_s = torch.nn.Parameter( + torch.ones(E, H, _n_scale_cols(IN), dtype=scale_dtype), requires_grad=False + ) + set_weight_attrs( + w2_s, + {"quant_method": "channel" if g == -1 else "group", **extra_weight_attrs}, + ) + layer.register_parameter("w2_weight_scale", w2_s) + + if self.has_bias: + w13_bias = torch.nn.Parameter( + torch.zeros(E, 2 * IN, dtype=params_dtype), requires_grad=False + ) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + + w2_bias = torch.nn.Parameter( + torch.zeros(num_experts, hidden_size, dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + + # Placeholders for packed weights (will be replaced after packing) + layer.register_parameter( + "w13_weight_packed", torch.nn.Parameter(torch.empty(0), requires_grad=False) + ) + set_weight_attrs(layer.w13_weight_packed, extra_weight_attrs) + + layer.register_parameter( + "w2_weight_packed", torch.nn.Parameter(torch.empty(0), requires_grad=False) + ) + set_weight_attrs(layer.w2_weight_packed, extra_weight_attrs) + + # dims for 4 bit fused matmuls + layer.w13_in_features = H + layer.w13_out_features = 2 * IN + layer.w2_in_features = IN + layer.w2_out_features = H + layer.group_size = g + + # post-load packing to dyn-4bit KleidiAI kernel's format + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + E = layer.w13_weight.shape[0] + H = layer.w13_in_features + I2 = layer.w13_out_features + IN = layer.w2_in_features + g = layer.group_size + + def _pack_matrix( + int4_as_int8_2d: torch.Tensor, + scales_2d: torch.Tensor, + bias_1d: torch.Tensor | None, + in_features: int, + out_features: int, + ) -> torch.Tensor: + # int4 values are stored as int8 in [-8,7]. + # Shift to unsigned nibble and pack pairs along input-dim. 
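+            # Worked example (hypothetical row): [-8, 7, 0, 3] shifts to
+            # [0, 15, 8, 11]; packing adjacent pairs puts the odd-index value
+            # in the high nibble, giving bytes [0xF0, 0xB8].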
+ tmp = int4_as_int8_2d.add(8) # [out, in] + uint8_nibbles = ((tmp[:, 1::2] << 4) | tmp[:, ::2]).to( + torch.uint8 + ) # [out, in//2] + + # KleidiAI groupwise kernels accepts float32 scales + # KleidiAI groupwise kernels accepts bfloat16 scales + scale_dtype = torch.float32 if g == -1 else torch.bfloat16 + scales = scales_2d.to(scale_dtype) + bias = None if bias_1d is None else bias_1d.to(torch.float32) + return torch.ops.aten._dyn_quant_pack_4bit_weight( + uint8_nibbles, + scales, + bias, + g if g != -1 else in_features, + in_features, + out_features, + ) + + # Pack per expert + w13_packed_list = [] + w2_packed_list = [] + + has_w13_bias = hasattr(layer, "w13_bias") and layer.w13_bias is not None + has_w2_bias = hasattr(layer, "w2_bias") and layer.w2_bias is not None + + for e in range(E): + w13_packed_list.append( + _pack_matrix( + layer.w13_weight[e], # [2I, H] + layer.w13_weight_scale[e], # [2I, H/g or 1] + layer.w13_bias[e] if has_w13_bias else None, # [2I] + H, + I2, + ) + ) + w2_packed_list.append( + _pack_matrix( + # w2 shape is [H, IN]; we need [out, in] == [H, IN]. + layer.w2_weight[e], # [H, IN] + layer.w2_weight_scale[e], # [H, IN/g or 1] + layer.w2_bias[e] if has_w2_bias else None, # [H] + IN, + layer.w2_out_features, # in_features=IN, out_features=H + ) + ) + + # each packed tensor has identical shape per expert; stack on dim 0 + w13_packed = torch.stack(w13_packed_list, dim=0) + w2_packed = torch.stack(w2_packed_list, dim=0) + + replace_parameter( + layer, + "w13_weight_packed", + torch.nn.Parameter(w13_packed, requires_grad=False), + ) + replace_parameter( + layer, + "w2_weight_packed", + torch.nn.Parameter(w2_packed, requires_grad=False), + ) + + # free raw tensors/scales/bias now that they're packed into the payload. + replace_parameter( + layer, "w13_weight", torch.nn.Parameter(torch.empty(0), requires_grad=False) + ) + replace_parameter( + layer, "w2_weight", torch.nn.Parameter(torch.empty(0), requires_grad=False) + ) + replace_parameter( + layer, + "w13_weight_scale", + torch.nn.Parameter(torch.empty(0), requires_grad=False), + ) + replace_parameter( + layer, + "w2_weight_scale", + torch.nn.Parameter(torch.empty(0), requires_grad=False), + ) + if has_w13_bias: + replace_parameter( + layer, + "w13_bias", + torch.nn.Parameter(torch.empty(0), requires_grad=False), + ) + if has_w2_bias: + replace_parameter( + layer, + "w2_bias", + torch.nn.Parameter(torch.empty(0), requires_grad=False), + ) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + # CPU dynamic 4-bit MoE path does not use modular kernels or + # fused_experts; quant config is not needed. + return None + + @property + def is_monolithic(self) -> bool: + return True + + def apply_monolithic( + self, + layer: FusedMoE, + x: torch.Tensor, + router_logits: torch.Tensor, + ) -> torch.Tensor: + assert not layer.enable_eplb, "EPLB not supported for W4A8-int MoE yet." + assert layer.activation in ( + MoEActivation.SILU, + MoEActivation.SWIGLUOAI, + MoEActivation.SWIGLUSTEP, + ), "Only SiLU/SwiGLUGU/SwiGLUUG are supported." 
+ assert layer.expert_map is None, """expert_map/EP not implemented + for CPU dyn-4bit MoE.""" + + def _act_kind(s: MoEActivation) -> int: + # 0 = SwiGLU_Gu (SiLU(g)*u), 1 = SwiGLU_Ug (SiLU(u)*g), 2 = SiLU + if s == MoEActivation.SWIGLUSTEP: + return 0 + if s == MoEActivation.SWIGLUOAI: + return 1 + if s == MoEActivation.SILU: + return 2 + raise ValueError(f"Unknown activation '{s}'") + + # Apply topk softmax on router output + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=layer.top_k, + use_grouped_topk=layer.use_grouped_topk, + renormalize=layer.renormalize, + ) + + return torch.ops._C.dynamic_4bit_int_moe( + x, + topk_ids.to(torch.long), + topk_weights, + layer.w13_weight_packed, + layer.w2_weight_packed, + layer.w2_out_features, + layer.w2_in_features, + layer.w13_out_features, + layer.group_size, + layer.apply_router_weight_on_input, + int(_act_kind(layer.activation)), + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py new file mode 100644 index 000000000..ed8ed79c5 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py @@ -0,0 +1,414 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import torch +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationStrategy, +) + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, + FusedMoeWeightScaleSupported, +) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) +from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( + convert_to_fp8_moe_kernel_format, + make_fp8_moe_kernel, + make_fp8_moe_quant_config, + select_fp8_moe_backend, +) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa E501 + CompressedTensorsMoEMethod, +) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + process_fp8_input_tensor_strategy_moe, + process_fp8_weight_tensor_strategy_moe, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8Dynamic128Sym, + kFp8DynamicTokenSym, + kFp8Static128BlockSym, + kFp8StaticChannelSym, + kFp8StaticTensorSym, +) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + normalize_e4m3fn_to_e4m3fnuz, +) +from vllm.model_executor.utils import replace_parameter, set_weight_attrs +from vllm.platforms import current_platform + +logger = init_logger(__name__) + + +class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): + """W8A8 FP8 MoE quantization using compressed tensors.""" + + def __init__( + self, + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs, + moe: FusedMoEConfig, + layer_name: str | None = None, + ): + super().__init__(moe) + self.weight_quant = weight_quant + self.input_quant = input_quant + + per_tensor = ( + self.weight_quant.strategy == QuantizationStrategy.TENSOR + and self.input_quant.strategy == QuantizationStrategy.TENSOR + ) + per_channel = ( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL + and 
self.input_quant.strategy == QuantizationStrategy.TOKEN + ) + if not (per_tensor or per_channel): + assert self.weight_quant.strategy == QuantizationStrategy.BLOCK + self.weight_block_size = self.weight_quant.block_structure + assert self.weight_quant.dynamic is not None + else: + self.weight_block_size = None + self.block_quant = self.weight_block_size is not None + + self.static_input_scales = not self.input_quant.dynamic + if self.static_input_scales and per_channel: + raise ValueError( + "For FP8 Fused MoE layer, we require either per tensor or " + "channelwise, dynamic per token quantization." + ) + + ct2vllm_weight = { + QuantizationStrategy.CHANNEL: kFp8StaticChannelSym, + QuantizationStrategy.TENSOR: kFp8StaticTensorSym, + QuantizationStrategy.BLOCK: kFp8Static128BlockSym, + } + ct2vllm_act = { + QuantizationStrategy.TOKEN: kFp8DynamicTokenSym, + QuantizationStrategy.TENSOR: ( + kFp8StaticTensorSym if self.static_input_scales else kFp8Dynamic128Sym + ), + } + weight_key = ct2vllm_weight[self.weight_quant.strategy] + if weight_key == kFp8Static128BlockSym: + activation_key = kFp8Dynamic128Sym + else: + activation_key = ct2vllm_act[self.input_quant.strategy] + + # Select Fp8 MoE backend + self.fp8_backend, self.experts_cls = select_fp8_moe_backend( + config=self.moe, + weight_key=weight_key, + activation_key=activation_key, + allow_vllm_cutlass=True, + ) + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + layer.num_experts = num_experts + layer.orig_dtype = params_dtype + layer.weight_block_size = None + + params_dtype = torch.float8_e4m3fn + w13_num_shards = 2 if self.moe.is_act_and_mul else 1 + + if self.block_quant: + assert self.weight_block_size is not None + layer.weight_block_size = self.weight_block_size + tp_size = get_tensor_model_parallel_world_size() + block_n, block_k = ( + self.weight_block_size[0], + self.weight_block_size[1], + ) + # NOTE: To ensure proper alignment of the block-wise quantization + # scales, the output_size of the weights for both the gate and up + # layers must be divisible by block_n. + # Required by column parallel or enabling merged weights + if intermediate_size_per_partition % block_n != 0: + raise ValueError( + f"The output_size of gate's and up's weight = " + f"{intermediate_size_per_partition} is not divisible by " + f"weight quantization block_n = {block_n}." + ) + if tp_size > 1 and intermediate_size_per_partition % block_k != 0: + # Required by row parallel + raise ValueError( + f"The input_size of down's weight = " + f"{intermediate_size_per_partition} is not divisible by " + f"weight quantization block_k = {block_k}." + ) + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + w13_num_shards * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + if self.weight_quant.strategy == QuantizationStrategy.TENSOR: + # For gated MoE, allocate 2 scales for w1 and w3 respectively. 
+ # They will be combined to a single scale after weight loading. + # For non-gated MoE, allocate 1 scale for w13. + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, w13_num_shards, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-TENSOR quantization for FusedMoE.weight_loader. + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL: + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + w13_num_shards * intermediate_size_per_partition, + 1, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-CHANNEL quantization for FusedMoE.weight_loader. + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + elif self.weight_quant.strategy == QuantizationStrategy.BLOCK: + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + w13_num_shards + * ((intermediate_size_per_partition + block_n - 1) // block_n), + (hidden_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + (hidden_size + block_n - 1) // block_n, + (intermediate_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-CHANNEL quantization for FusedMoE.weight_loader. + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.static_input_scales: + w13_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: FusedMoE) -> None: + # Allow for accessing weights and scales in standard way. + w13 = layer.w13_weight + w2 = layer.w2_weight + w13_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + w13_input_scale = layer.w13_input_scale + w2_input_scale = layer.w2_input_scale + + # MI300x and MI325x use FNUZ format for FP8. Convert if needed. 
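+        # e4m3fnuz uses a different exponent bias and NaN encoding than
+        # e4m3fn, so the helper below rewrites the fp8 payloads and returns
+        # matching weight/input scales.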
+ if current_platform.is_fp8_fnuz(): + w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz( + w13, w13_scale, w13_input_scale + ) + w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz( + w2, w2_scale, w2_input_scale + ) + + # Per tensor kernels require single activation scale. Use the max. + if self.static_input_scales: + assert self.input_quant.strategy == QuantizationStrategy.TENSOR + assert w13_input_scale is not None and w2_input_scale is not None + w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe( + w13_input_scale, w2_input_scale + ) + replace_parameter(layer, "w13_input_scale", w13_input_scale) + replace_parameter(layer, "w2_input_scale", w2_input_scale) + + # Per-tensor kernels use a single scale, for W13, but on disk there + # is a separate scale for W1 and W3. Requantize with the max scale. + if self.weight_quant.strategy == QuantizationStrategy.TENSOR: + w13, w13_scale = process_fp8_weight_tensor_strategy_moe( + w13, + w13_scale, + shard_size=layer.intermediate_size_per_partition, + num_experts=layer.local_num_experts, + is_act_and_mul=self.moe.is_act_and_mul, + ) + + w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format( + fp8_backend=self.fp8_backend, + layer=layer, + w13=w13, + w2=w2, + w13_scale=w13_scale, + w2_scale=w2_scale, + w13_input_scale=w13_input_scale, + w2_input_scale=w2_input_scale, + ) + + # Replace parameters with updated versions. Note that this helper + # function ensures the replacement is compatible with RL weight reloads. + replace_parameter(layer, "w13_weight", w13) + replace_parameter(layer, "w2_weight", w2) + replace_parameter(layer, "w13_weight_scale", w13_scale) + replace_parameter(layer, "w2_weight_scale", w2_scale) + + # Setup modular kernel for TP case and naive DP/EP case. + # In non-naive DP/EP case, we will create a ModularKernelMethod. + # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel + # in both cases. + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + if self.moe_quant_config: + assert self.experts_cls is not None + self.moe_kernel = make_fp8_moe_kernel( + moe_quant_config=self.moe_quant_config, + moe_config=self.moe, + fp8_backend=self.fp8_backend, + experts_cls=self.experts_cls, + routing_tables=layer._maybe_init_expert_routing_tables(), + shared_experts=layer.shared_experts, + ) + + def maybe_make_prepare_finalize( + self, + routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + ) -> mk.FusedMoEPrepareAndFinalizeModular | None: + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." 
+ ) + + def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: + is_per_token = self.input_quant.strategy == QuantizationStrategy.TOKEN + return make_fp8_moe_quant_config( + fp8_backend=self.fp8_backend, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + per_act_token_quant=is_per_token, + per_out_ch_quant=is_per_token, + block_shape=self.weight_block_size, + ) + + def apply_monolithic( + self, + layer: FusedMoE, + x: torch.Tensor, + router_logits: torch.Tensor, + ) -> torch.Tensor: + assert self.moe_kernel is not None + return self.moe_kernel.apply_monolithic( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + e_score_correction_bias=layer.e_score_correction_bias, + routed_scaling_factor=layer.routed_scaling_factor, + ) + + def apply( + self, + layer: FusedMoE, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor: + assert not self.is_monolithic + assert self.moe_kernel is not None + return self.moe_kernel.apply( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + # TODO(rob): investigate the disable_expert_map introduced by: + # https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501 + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + shared_experts_input=shared_experts_input, + ) + + @property + def supports_eplb(self) -> bool: + return True diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_int8.py new file mode 100644 index 000000000..de155f9e1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_int8.py @@ -0,0 +1,161 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import torch +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationStrategy, +) + +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, + FusedMoeWeightScaleSupported, +) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, + int8_w8a8_moe_quant_config, +) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa E501 + CompressedTensorsMoEMethod, +) +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) + + +class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): + def __init__( + self, + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs, + moe: FusedMoEConfig, + layer_name: str | None = None, + ): + super().__init__(moe) + self.weight_quant = weight_quant + self.input_quant = input_quant + + per_channel = ( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL + and self.input_quant.strategy == 
QuantizationStrategy.TOKEN + ) + if not per_channel: + raise ValueError( + "For INT8 Fused MoE layers, we require channelwise, " + "dynamic per token quantization. Found " + f"{self.weight_quant}, {self.input_quant}" + ) + + self.static_input_scales = not self.input_quant.dynamic + if self.static_input_scales: + raise ValueError( + "For INT8 Fused MoE layers, we require channelwise, " + "dynamic per token quantization. Found static input scales." + ) + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + params_dtype = torch.int8 + w13_num_shards = 2 if self.moe.is_act_and_mul else 1 + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + w13_num_shards * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + w13_num_shards * intermediate_size_per_partition, + 1, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-CHANNEL quantization for FusedMoE.weight_loader. 
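+        # e.g. (hypothetical, gated/is_act_and_mul MoE): E=16 experts with
+        # I_local=2048 gives w13_weight_scale of shape [16, 4096, 1] and
+        # w2_weight_scale of shape [16, hidden_size, 1] -- one float32 scale
+        # per output row.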
+ extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + assert not self.static_input_scales + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + pass + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + return int8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + per_act_token_quant=True, + ) + + def apply( + self, + layer: FusedMoE, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts + + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=not self.moe.disable_inplace, + activation=layer.activation, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + quant_config=self.moe_quant_config, + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16.py new file mode 100644 index 000000000..f530a1a1d --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16.py @@ -0,0 +1,267 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import torch +from compressed_tensors.quantization import ( + QuantizationArgs, +) + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, +) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, + int4_w4a16_moe_quant_config, + int8_w8a16_moe_quant_config, +) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa E501 + CompressedTensorsMoEMethod, +) +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) + + +class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + def __init__( + self, + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs | None, + moe: FusedMoEConfig, + layer_name: str | None = None, + ): + super().__init__(moe) + self.weight_quant = weight_quant + self.input_quant = input_quant + # Extract properties from weight_quant + self.num_bits = weight_quant.num_bits + self.packed_factor = 32 // weight_quant.num_bits + self.strategy = weight_quant.strategy + # channelwise is not supported by this kernel + assert weight_quant.strategy == "group" + self.group_size = weight_quant.group_size + # grouped actorder isn't supported by this kernel + assert weight_quant.actorder != "group" + assert weight_quant.symmetric, ( + "Only symmetric quantization is supported for MoE" + ) + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, 
+ **extra_weight_attrs, + ): + # Will transpose the loaded weight along the + # intermediate and hidden dim sizes. Will + # shard for TP along the transposed dims + extra_weight_attrs.update( + {"is_transposed": True, "quant_method": self.strategy} + ) + w13_num_shards = 2 if self.moe.is_act_and_mul else 1 + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size // self.packed_factor, + w13_num_shards * intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition // self.packed_factor, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w2_scales_size = intermediate_size_per_partition + + if self.strategy == "channel": + num_groups_w2 = num_groups_w13 = 1 + self.group_size = -1 + else: + num_groups_w2 = w2_scales_size // self.group_size + num_groups_w13 = hidden_size // self.group_size + + w13_scale = torch.nn.Parameter( + torch.ones( + num_experts, + num_groups_w13, + w13_num_shards * intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_scale) + set_weight_attrs(w13_scale, extra_weight_attrs) + + w2_scale = torch.nn.Parameter( + torch.ones(num_experts, num_groups_w2, hidden_size, dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_scale) + set_weight_attrs(w2_scale, extra_weight_attrs) + set_weight_attrs(w2_scale, {"load_full_w2": False}) + + w2_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) + layer.register_parameter("w2_weight_shape", w2_weight_shape) + set_weight_attrs(w2_weight_shape, extra_weight_attrs) + w13_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) + + layer.register_parameter("w13_weight_shape", w13_weight_shape) + set_weight_attrs(w13_weight_shape, extra_weight_attrs) + + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + + layer.a13_scale = None + layer.a2_scale = None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Reconfigure packed weights and scales to match moe_wna16 format + layer.w13_weight_packed = torch.nn.Parameter( + 
layer.w13_weight_packed.transpose(1, 2).contiguous().view(torch.uint8), + requires_grad=False, + ) + layer.w2_weight_packed = torch.nn.Parameter( + layer.w2_weight_packed.transpose(1, 2).contiguous().view(torch.uint8), + requires_grad=False, + ) + layer.w13_weight_scale = torch.nn.Parameter( + layer.w13_weight_scale.transpose(1, 2).contiguous(), requires_grad=False + ) + layer.w2_weight_scale = torch.nn.Parameter( + layer.w2_weight_scale.transpose(1, 2).contiguous(), requires_grad=False + ) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + assert self.num_bits == 4 or self.num_bits == 8 + config_builder = ( + int4_w4a16_moe_quant_config + if self.num_bits == 4 + else int8_w8a16_moe_quant_config + ) + + return config_builder( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + w1_zp=None, + w2_zp=None, + block_shape=[0, self.group_size], + ) + + def select_gemm_impl( + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular, + layer: torch.nn.Module, + ) -> mk.FusedMoEExpertsModular: + if self.moe.is_lora_enabled: + assert self.moe_quant_config is not None + from vllm.triton_utils import HAS_TRITON + + if HAS_TRITON: + from vllm.model_executor.layers.fused_moe import TritonWNA16Experts + + layer.w13_weight = layer.w13_weight_packed + layer.w2_weight = layer.w2_weight_packed + return TritonWNA16Experts( + moe_config=self.moe, quant_config=self.moe_quant_config + ) + else: + raise NotImplementedError( + "TritonExperts requires Triton. " + "Install triton or disable LoRA for MoE." + ) + + raise NotImplementedError + + def apply( + self, + layer: FusedMoE, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts + + return fused_experts( + x, + layer.w13_weight_packed, + layer.w2_weight_packed, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=not self.moe.disable_inplace, + activation=layer.activation, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + quant_config=self.moe_quant_config, + ) + + @property + def supports_eplb(self) -> bool: + return True diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16_marlin.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16_marlin.py new file mode 100644 index 000000000..216eed637 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16_marlin.py @@ -0,0 +1,575 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import enum +from enum import Enum + +import torch +from compressed_tensors.quantization import ( + QuantizationArgs, +) + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, +) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, + int4_w4a16_moe_quant_config, +) +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + BatchedMarlinExperts, + MarlinExperts, + fused_marlin_moe, +) +from 
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa E501 + CompressedTensorsMoEMethod, +) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa + WNA16_SUPPORTED_TYPES_MAP, +) +from vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe import ( + flashinfer_trtllm_mxint4_moe, + is_flashinfer_mxint4_moe_available, + prepare_static_weights_for_trtllm_mxint4_moe, +) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + get_marlin_input_dtype, + marlin_act_int8_process_scales, + marlin_make_workspace_new, + marlin_moe_permute_scales, +) +from vllm.model_executor.utils import replace_parameter, set_weight_attrs + +logger = init_logger(__name__) + + +class GPTQMarlinState(Enum): + REPACK = enum.auto() + READY = enum.auto() + + +class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): + def __init__( + self, + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs | None, + moe: FusedMoEConfig, + layer_name: str | None = None, + ): + super().__init__(moe) + self.weight_quant = weight_quant + self.input_quant = input_quant + assert weight_quant.symmetric, ( + "Only symmetric quantization is supported for MoE" + ) + # Extract properties from weight_quant + self.num_bits = weight_quant.num_bits + self.packed_factor = 32 // weight_quant.num_bits + self.strategy = weight_quant.strategy + self.group_size = weight_quant.group_size + self.actorder = weight_quant.actorder + + self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits] + + self.marlin_input_dtype = get_marlin_input_dtype(layer_name) + self.use_flashinfer_mxint4_moe = ( + is_flashinfer_mxint4_moe_available() + and self.group_size == 32 + and weight_quant.num_bits == 4 + ) + self.kernel_backend = ( + "Flashinfer" if self.use_flashinfer_mxint4_moe else "Marlin" + ) + logger.info_once( + f"Using {self.kernel_backend} backend for WNA16 MoE " + f"(group_size={self.group_size}, num_bits={self.num_bits})", + scope="local", + ) + + def get_weight_shape( + self, + weight_name: str, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + num_groups_w2: int | None = None, + num_groups_w13: int | None = None, + ) -> tuple[int, int, int]: + """ + Get the shape of the weight based on the weight name, number of experts + hidden size, intermediate size per partition, number of groups for w2, + and number of groups for w13. Pass in num_groups_w2 and num_groups_w13 + for weight scales. 
+ """ + if weight_name == "w13_scale": + assert num_groups_w13 is not None, ( + "num_groups_w13 must be provided for weight scales" + ) + if weight_name == "w2_scale": + assert num_groups_w2 is not None, ( + "num_groups_w2 must be provided for weight scales" + ) + w13_num_shards = 2 if self.moe.is_act_and_mul else 1 + shape_map = { + "w13_weight": { + "Flashinfer": ( + num_experts, + w13_num_shards * intermediate_size_per_partition, + hidden_size // self.packed_factor, + ), + "Marlin": ( + num_experts, + hidden_size // self.packed_factor, + w13_num_shards * intermediate_size_per_partition, + ), + }, + "w13_scale": { + "Flashinfer": ( + num_experts, + w13_num_shards * intermediate_size_per_partition, + num_groups_w13, + ), + "Marlin": ( + num_experts, + num_groups_w13, + w13_num_shards * intermediate_size_per_partition, + ), + }, + "w2_weight": { + "Flashinfer": ( + num_experts, + hidden_size, + intermediate_size_per_partition // self.packed_factor, + ), + "Marlin": ( + num_experts, + intermediate_size_per_partition // self.packed_factor, + hidden_size, + ), + }, + "w2_scale": { + "Flashinfer": (num_experts, hidden_size, num_groups_w2), + "Marlin": (num_experts, num_groups_w2, hidden_size), + }, + } + return shape_map[weight_name][self.kernel_backend] + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full") + + # Will transpose the loaded weight along the + # intermediate and hidden dim sizes. Will + # shard for TP along the transposed dims + is_transposed = self.kernel_backend != "Flashinfer" + extra_weight_attrs.update( + {"is_transposed": is_transposed, "quant_method": self.strategy} + ) + + w13_weight = torch.nn.Parameter( + torch.empty( + *self.get_weight_shape( + "w13_weight", + num_experts, + hidden_size, + intermediate_size_per_partition, + ), + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + *self.get_weight_shape( + "w2_weight", + num_experts, + hidden_size, + intermediate_size_per_partition, + ), + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # In the case where we have actorder/g_idx, + # we do not partition the w2 scales + load_full_w2 = self.actorder and self.group_size != -1 + w2_scales_size = ( + intermediate_size_full if load_full_w2 else intermediate_size_per_partition + ) + + self.is_k_full = (not self.actorder) or ( + intermediate_size_per_partition == intermediate_size_full + ) + + if self.strategy == "channel": + num_groups_w2 = num_groups_w13 = 1 + self.group_size = -1 + else: + num_groups_w2 = w2_scales_size // self.group_size + num_groups_w13 = hidden_size // self.group_size + + layer.num_groups_w13 = num_groups_w13 + layer.num_groups_w2 = num_groups_w2 + + w13_scale = torch.nn.Parameter( + torch.ones( + *self.get_weight_shape( + "w13_scale", + num_experts, + hidden_size, + intermediate_size_per_partition, + num_groups_w13=num_groups_w13, + ), + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_scale) + set_weight_attrs(w13_scale, extra_weight_attrs) + + w2_scale = torch.nn.Parameter( + torch.ones( + *self.get_weight_shape( 
+ "w2_scale", + num_experts, + hidden_size, + intermediate_size_per_partition, + num_groups_w2=num_groups_w2, + ), + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_scale) + set_weight_attrs(w2_scale, extra_weight_attrs) + set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2}) + + w2_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) + layer.register_parameter("w2_weight_shape", w2_weight_shape) + set_weight_attrs(w2_weight_shape, extra_weight_attrs) + w13_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) + + layer.register_parameter("w13_weight_shape", w13_weight_shape) + set_weight_attrs(w13_weight_shape, extra_weight_attrs) + + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + + layer.a13_scale = None + layer.a2_scale = None + layer.marlin_state = GPTQMarlinState.REPACK + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + num_experts = layer.w13_weight_g_idx.shape[0] + device = layer.w13_weight_g_idx.device + if self.kernel_backend == "Flashinfer": + dict_weights_mxint4 = prepare_static_weights_for_trtllm_mxint4_moe( + layer.w13_weight_packed, + layer.w13_weight_scale, + layer.w2_weight_packed, + layer.w2_weight_scale, + ) + replace_parameter( + layer, "w13_weight_packed", dict_weights_mxint4["gemm1_weights"] + ) + replace_parameter( + layer, "w13_weight_scale", dict_weights_mxint4["gemm1_scales"] + ) + replace_parameter( + layer, "w2_weight_packed", dict_weights_mxint4["gemm2_weights"] + ) + replace_parameter( + layer, "w2_weight_scale", dict_weights_mxint4["gemm2_scales"] + ) + return None + + is_a_8bit = ( + self.marlin_input_dtype is not None + and self.marlin_input_dtype.itemsize == 1 + ) + + if self.marlin_input_dtype == torch.float8_e4m3fn: + # NOTE: for non-zp quantization format only + ops.marlin_int4_fp8_preprocess(layer.w13_weight_packed, inplace=True) + ops.marlin_int4_fp8_preprocess(layer.w2_weight_packed, inplace=True) + layer.w13_weight_scale.data = layer.w13_weight_scale.data * 512 + layer.w2_weight_scale.data = layer.w2_weight_scale.data * 512 + + # when running models with grouped act order, + # resort to g_idx values provided in checkpoint + if self.actorder == "group": + w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx) + w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx) + w13_sorted_g_idx = torch.empty_like(layer.w13_weight_g_idx) + w2_sorted_g_idx = 
torch.empty_like(layer.w2_weight_g_idx) + + for e in range(num_experts): + w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_weight_g_idx[e]).to( + torch.int32 + ) + w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_weight_g_idx[e]).to( + torch.int32 + ) + w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][ + w13_g_idx_sort_indices[e] + ] + w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][w2_g_idx_sort_indices[e]] + + replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx) + replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx) + replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices) + replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices) + + else: + layer.w13_weight_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_weight_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_weight_packed, + layer.w13_g_idx_sort_indices, + layer.w13_weight_packed.shape[1] * self.packed_factor, + layer.w13_weight_packed.shape[2], + self.num_bits, + is_a_8bit=is_a_8bit, + ) + replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight) + + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_weight_packed, + layer.w2_g_idx_sort_indices, + layer.w2_weight_packed.shape[1] * self.packed_factor, + layer.w2_weight_packed.shape[2], + self.num_bits, + is_a_8bit=is_a_8bit, + ) + replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight) + + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_weight_scale, + size_k=layer.w13_weight_packed.shape[2], + size_n=layer.w13_weight_scale.shape[2], + group_size=self.group_size, + is_a_8bit=is_a_8bit, + ) + if self.marlin_input_dtype == torch.int8 and layer.num_groups_w13 > 1: + marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales( + marlin_w13_scales + ) + layer.register_parameter( + "w13_input_global_scale", + torch.nn.Parameter(w13_input_global_scale, requires_grad=False), + ) + replace_parameter(layer, "w13_weight_scale", marlin_w13_scales) + + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_weight_scale, + size_k=layer.w2_weight_scale.shape[1] + * (self.group_size if self.group_size != -1 else self.packed_factor), + size_n=layer.w2_weight_scale.shape[2], + group_size=self.group_size, + is_a_8bit=is_a_8bit, + ) + if self.marlin_input_dtype == torch.int8 and layer.num_groups_w2 > 1: + marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales( + marlin_w2_scales + ) + layer.register_parameter( + "w2_input_global_scale", + torch.nn.Parameter(w2_input_global_scale, requires_grad=False), + ) + replace_parameter(layer, "w2_weight_scale", marlin_w2_scales) + + layer.workspace = marlin_make_workspace_new(device, 4) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + if self.num_bits != 4: + return None + return int4_w4a16_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + w1_zp=None, + w2_zp=None, + block_shape=[0, self.group_size], + ) + + def 
select_gemm_impl( + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular, + layer: torch.nn.Module, + ) -> mk.FusedMoEExpertsModular: + assert self.num_bits == 4, "only supporting w4" + layer.w13_weight = layer.w13_weight_packed + layer.w2_weight = layer.w2_weight_packed + assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]]) + assert self.moe_quant_config is not None + if ( + prepare_finalize.activation_format + == mk.FusedMoEActivationFormat.BatchedExperts + ): + max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() + assert max_num_tokens_per_rank is not None + return BatchedMarlinExperts( + max_num_tokens=max_num_tokens_per_rank, + num_dispatchers=prepare_finalize.num_dispatchers(), + moe_config=self.moe, + quant_config=self.moe_quant_config, + w13_g_idx=layer.w13_weight_g_idx, + w2_g_idx=layer.w2_weight_g_idx, + w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices, + w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices, + is_k_full=self.is_k_full, + ) + else: + return MarlinExperts( + moe_config=self.moe, + quant_config=self.moe_quant_config, + w13_g_idx=layer.w13_weight_g_idx, + w2_g_idx=layer.w2_weight_g_idx, + w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices, + w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices, + is_k_full=self.is_k_full, + ) + + @property + def is_monolithic(self) -> bool: + return self.kernel_backend == "Flashinfer" + + def apply_monolithic( + self, + layer: FusedMoE, + x: torch.Tensor, + router_logits: torch.Tensor, + ) -> torch.Tensor: + assert self.kernel_backend == "Flashinfer" + return flashinfer_trtllm_mxint4_moe( + x=x, + router_logits=router_logits, + w13_weight_packed=layer.w13_weight_packed, + w13_weight_scale=layer.w13_weight_scale, + w2_weight_packed=layer.w2_weight_packed, + w2_weight_scale=layer.w2_weight_scale, + global_num_experts=layer.global_num_experts, + top_k=layer.top_k, + intermediate_size_per_partition=layer.intermediate_size_per_partition, + local_num_experts=layer.local_num_experts, + ep_rank=layer.ep_rank, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + e_score_correction_bias=layer.e_score_correction_bias, + routing_method_type=layer.routing_method_type, + ) + + def apply( + self, + layer: FusedMoE, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor: + assert self.kernel_backend == "Marlin" + return fused_marlin_moe( + x, + layer.w13_weight_packed, + layer.w2_weight_packed, + None, + None, + layer.w13_weight_scale, + layer.w2_weight_scale, + topk_weights, + topk_ids, + input_global_scale1=getattr(layer, "w13_input_global_scale", None), + input_global_scale2=getattr(layer, "w2_input_global_scale", None), + quant_type_id=self.quant_type.id, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + activation=layer.activation, + expert_map=layer.expert_map, + g_idx1=layer.w13_weight_g_idx, + g_idx2=layer.w2_weight_g_idx, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, + workspace=layer.workspace, + input_dtype=self.marlin_input_dtype, + is_k_full=self.is_k_full, + inplace=not self.moe.disable_inplace, + )
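For reference, a minimal sketch of the packed-weight shape bookkeeping used by the WNA16 paths above, assuming hypothetical sizes (`E`, `H`, `N`, `num_bits` below are illustrative only): int4 values are packed eight to an int32 word along the hidden dimension, then transposed and reinterpreted as raw bytes for the moe_wna16 layout. This is the same arithmetic that makes `hidden_size // self.packed_factor` the packed dimension in `create_weights`.

```python
import torch

# Hypothetical sizes: E experts, hidden size H, N output columns
# (e.g. 2 * intermediate size for the fused gate/up projection), 4-bit weights.
E, H, N, num_bits = 2, 256, 64, 4
packed_factor = 32 // num_bits          # 8 int4 values per int32 word

# Loader-facing (transposed) layout: [E, H // packed_factor, N], dtype int32.
w13_packed = torch.zeros(E, H // packed_factor, N, dtype=torch.int32)

# moe_wna16 layout: transpose back to [E, N, H // packed_factor] and view the
# int32 words as raw bytes, i.e. [E, N, H // 2] uint8 (two int4 per byte).
w13_moe_wna16 = w13_packed.transpose(1, 2).contiguous().view(torch.uint8)
print(w13_moe_wna16.shape)              # torch.Size([2, 64, 128])
```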