From 727c41f3fd7c4fb183b18eb1a41a6d5be8a42c17 Mon Sep 17 00:00:00 2001
From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Date: Sat, 27 Dec 2025 15:22:48 -0500
Subject: [PATCH] [MoE Refactor][10/N] Cleanup Fp8 Process Weights After
 Loading (#31169)

Signed-off-by: Robert Shaw
Co-authored-by: Robert Shaw
---
 .../model_executor/layers/quantization/fp8.py | 344 ++++++++----------
 1 file changed, 150 insertions(+), 194 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 9da19c082..3d571a77d 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Optional

 import torch
 from torch.nn import Module
-from torch.nn.parameter import Parameter
 from torch.utils._python_dispatch import TorchDispatchMode

 import vllm.envs as envs
@@ -728,6 +727,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.quant_config = quant_config
         self.weight_block_size = self.quant_config.weight_block_size
         self.block_quant: bool = self.weight_block_size is not None
+        self.weight_scale_name = (
+            "weight_scale_inv" if self.block_quant else "weight_scale"
+        )
         self.fp8_backend = get_fp8_moe_backend(
             self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled
         )
@@ -832,38 +834,28 @@ class Fp8MoEMethod(FusedMoEMethodBase):

         # WEIGHT_SCALES
         if not self.block_quant:
-            # Allocate 2 scales for w1 and w3 respectively.
-            # They will be combined to a single scale after weight loading.
-            w13_weight_scale = torch.nn.Parameter(
-                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
-            )
-            w2_weight_scale = torch.nn.Parameter(
-                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-            )
-            layer.register_parameter("w13_weight_scale", w13_weight_scale)
-            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # For per-tensor quant, the scales are per expert and weight.
+            w13_scale_data = torch.ones(num_experts, 2, dtype=torch.float32)
+            w2_scale_data = torch.ones(num_experts, dtype=torch.float32)
         else:
-            w13_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    num_experts,
-                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
-                    (hidden_size + block_k - 1) // block_k,
-                    dtype=torch.float32,
-                ),
-                requires_grad=False,
+            # For block quant, the scales are per block (typically 128x128).
+            w13_scale_data = torch.ones(
+                num_experts,
+                2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
+                (hidden_size + block_k - 1) // block_k,
+                dtype=torch.float32,
             )
-            w2_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    num_experts,
-                    (hidden_size + block_n - 1) // block_n,
-                    (intermediate_size_per_partition + block_k - 1) // block_k,
-                    dtype=torch.float32,
-                ),
-                requires_grad=False,
+            w2_scale_data = torch.ones(
+                num_experts,
+                (hidden_size + block_n - 1) // block_n,
+                (intermediate_size_per_partition + block_k - 1) // block_k,
+                dtype=torch.float32,
             )
-            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
-            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
-            assert self.quant_config.activation_scheme == "dynamic"
+        w13_weight_scale = torch.nn.Parameter(w13_scale_data, requires_grad=False)
+        w2_weight_scale = torch.nn.Parameter(w2_scale_data, requires_grad=False)
+        # Note: name is weight_scale for tensor, weight_scale_inv for block.
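+        # ("weight_scale_inv" matches the name used in block-quantized FP8
+        # checkpoints such as DeepSeek-V3, so the loader can map checkpoint
+        # tensors onto these parameters directly.)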
+ layer.register_parameter(f"w13_{self.weight_scale_name}", w13_weight_scale) + layer.register_parameter(f"w2_{self.weight_scale_name}", w2_weight_scale) # Add the quantization method used (per tensor/grouped/channel) # to ensure the weight scales are loaded in properly @@ -877,6 +869,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): # INPUT_SCALES if self.quant_config.activation_scheme == "static": + assert not self.block_quant w13_input_scale = torch.nn.Parameter( torch.ones(num_experts, dtype=torch.float32), requires_grad=False ) @@ -893,158 +886,60 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w13_input_scale = None layer.w2_input_scale = None - def process_weights_after_loading(self, layer: Module) -> None: - if getattr(layer, "_already_called_process_weights_after_loading", False): - return - - # TODO (rob): refactor block quant into separate class. - if self.block_quant: - assert self.quant_config.activation_scheme == "dynamic" - if current_platform.is_fp8_fnuz(): - w13_weight, w13_weight_scale_inv, w13_input_scale = ( - normalize_e4m3fn_to_e4m3fnuz( - layer.w13_weight, - layer.w13_weight_scale_inv, - layer.w13_input_scale, - ) - ) - w2_weight, w2_weight_scale_inv, w2_input_scale = ( - normalize_e4m3fn_to_e4m3fnuz( - layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale - ) - ) - elif self.flashinfer_moe_backend is not None: - # NOTE: weights have to be swapped since the activation is - # applied on different half for flashinfer vs vllm - w13_weight = swap_w13_to_w31(layer.w13_weight.data) - w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data) - w2_weight = layer.w2_weight.data - w2_weight_scale_inv = layer.w2_weight_scale_inv.data + def _convert_weights_to_kernel_format( + self, + layer: Module, + w13_weight: torch.Tensor, + w2_weight: torch.Tensor, + w13_weight_scale: torch.Tensor, + w2_weight_scale: torch.Tensor, + ) -> None: + if self.fp8_backend == Fp8MoeBackend.DEEPGEMM: + assert self.block_quant + w13_weight, w13_weight_scale = deepgemm_post_process_fp8_weight_block( + wq=w13_weight, + ws=w13_weight_scale, + quant_block_shape=tuple(layer.weight_block_size), + use_e8m0=is_deep_gemm_e8m0_used(), + ) + w2_weight, w2_weight_scale = deepgemm_post_process_fp8_weight_block( + wq=w2_weight, + ws=w2_weight_scale, + quant_block_shape=tuple(layer.weight_block_size), + use_e8m0=is_deep_gemm_e8m0_used(), + ) + elif self.fp8_backend == Fp8MoeBackend.AITER: + w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights( + w13_weight, w2_weight + ) + elif self.fp8_backend in [ + Fp8MoeBackend.FLASHINFER_CUTLASS, + Fp8MoeBackend.FLASHINFER_TRTLLM, + ]: + w13_weight = swap_w13_to_w31(w13_weight) + if self.block_quant: + w13_weight_scale = swap_w13_to_w31(w13_weight_scale) else: - w13_weight = layer.w13_weight.data - w13_weight_scale_inv = layer.w13_weight_scale_inv.data - w2_weight = layer.w2_weight - w2_weight_scale_inv = layer.w2_weight_scale_inv - - # torch.compile() cannot use Parameter subclasses. - replace_parameter(layer, "w13_weight", w13_weight) - replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv) - replace_parameter(layer, "w2_weight", w2_weight) - replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv) - if self.fp8_backend == Fp8MoeBackend.AITER: - # reshaping weights is required for aiter moe kernel. 
-                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
-                    layer.w13_weight.data, layer.w2_weight.data
-                )
-
-                replace_parameter(layer, "w13_weight", shuffled_w13)
-                replace_parameter(layer, "w2_weight", shuffled_w2)
-
-            # DeepGemm scales need to be transposed and aligned. We try to do
-            # it ahead of time for performance reasons.
-            if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
-                dg_w13_weight, dg_w13_weight_scale_inv = (
-                    deepgemm_post_process_fp8_weight_block(
-                        wq=layer.w13_weight.data,
-                        ws=layer.w13_weight_scale_inv.data,
-                        quant_block_shape=tuple(layer.weight_block_size),
-                        use_e8m0=is_deep_gemm_e8m0_used(),
-                    )
-                )
-                dg_w2_weight, dg_w2_weight_scale_inv = (
-                    deepgemm_post_process_fp8_weight_block(
-                        wq=layer.w2_weight.data,
-                        ws=layer.w2_weight_scale_inv.data,
-                        quant_block_shape=tuple(layer.weight_block_size),
-                        use_e8m0=is_deep_gemm_e8m0_used(),
-                    )
-                )
-                layer.w13_weight = Parameter(dg_w13_weight, requires_grad=False)
-                layer.w13_weight_scale_inv = Parameter(
-                    dg_w13_weight_scale_inv, requires_grad=False
-                )
-                layer.w2_weight = Parameter(dg_w2_weight, requires_grad=False)
-                layer.w2_weight_scale_inv = Parameter(
-                    dg_w2_weight_scale_inv, requires_grad=False
-                )
-        else:
-            # Fp8 moe kernels require a single activation scale.
-            # We take the max of all the scales in case they differ.
-            if self.quant_config.activation_scheme == "static":
-                if layer.w13_input_scale is None or layer.w2_input_scale is None:
-                    raise ValueError(
-                        "QuantConfig has static quantization, but found "
-                        "activation scales are None."
-                    )
-                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
-                    layer.w2_input_scale
-                ):
-                    logger.warning_once(
-                        "Found input_scales that are not equal for "
-                        "fp8 MoE layer. Using the maximum across experts "
-                        "for each layer."
-                    )
-                replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
-                replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())
-            if current_platform.is_fp8_fnuz():
-                # Normalize the weights and scales
-                w13_weight, w13_weight_scale, w13_input_scale = (
-                    normalize_e4m3fn_to_e4m3fnuz(
-                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
-                    )
-                )
-                w2_weight, w2_weight_scale, w2_input_scale = (
-                    normalize_e4m3fn_to_e4m3fnuz(
-                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
-                    )
-                )
-                # Reset the parameter
-                replace_parameter(layer, "w13_weight", w13_weight)
-                replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
-                if w13_input_scale is not None:
-                    replace_parameter(layer, "w13_input_scale", w13_input_scale)
-                replace_parameter(layer, "w2_weight", w2_weight)
-                replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
-                if w2_input_scale is not None:
-                    replace_parameter(layer, "w2_input_scale", w2_input_scale)
-
-            # Fp8 moe kernel needs single weight scale for w13 per expert.
-            # We take the max then dequant and requant each expert.
-            assert layer.w13_weight_scale is not None
-            shard_size = layer.intermediate_size_per_partition
-            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-            for expert_id in range(layer.local_num_experts):
-                start = 0
-                for shard_id in range(2):
-                    dq_weight = per_tensor_dequantize(
-                        layer.w13_weight[expert_id][start : start + shard_size, :],
-                        layer.w13_weight_scale[expert_id][shard_id],
-                    )
-                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
-                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    )
-                    start += shard_size
-
-            if self.fp8_backend == Fp8MoeBackend.AITER:
-                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
-                    layer.w13_weight, layer.w2_weight
-                )
-
-                replace_parameter(layer, "w13_weight", shuffled_w13)
-                replace_parameter(layer, "w2_weight", shuffled_w2)
-
-            replace_parameter(layer, "w13_weight_scale", max_w13_scales)
-
-            if self.flashinfer_moe_backend is not None:
-                # NOTE: weights have to be swapped since the activation is
-                # applied on different half for flashinfer vs vllm
-                assert not self.block_quant
+                # TODO(rob): this function renames the scaling factors in
+                # the Module. This is a hack we should clean up.
                 register_moe_scaling_factors(layer)
-                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
-                if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
+            if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
                 rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
-                layer.w13_weight.data = w13_weight.data
+        elif self.fp8_backend == Fp8MoeBackend.AITER:
+            w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
+                w13_weight, w2_weight
+            )
+        # Replace parameters with updated versions. Note that this helper
+        # function ensures the replacement is compatible with RL weight reloads.
+        replace_parameter(layer, "w13_weight", w13_weight)
+        replace_parameter(layer, "w2_weight", w2_weight)
+        replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_weight_scale)
+        replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_weight_scale)
+
+        # TODO(rob): we do this after replace_parameter() because
+        # prepare_moe_fp8_layer_for_marlin operates on the layer's params
+        # directly. We will refactor this in a follow-up PR.
         if self.fp8_backend == Fp8MoeBackend.MARLIN:
             prepare_moe_fp8_layer_for_marlin(
                 layer, False, input_dtype=self.marlin_input_dtype
@@ -1053,6 +948,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             del layer.w13_input_scale
             del layer.w2_input_scale

+    def _setup_kernel(self, layer: Module) -> None:
+        """Set up the modular kernel for the TP case."""
         # NOTE(rob): this is a WIP refactor. We are first migrating
         # all of the kernels in the TP case to use mk. Once this is
         # done, then we will initialize the TP case and DP/EP case
@@ -1134,6 +1031,71 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             )
             self.use_inplace = True

+    def process_weights_after_loading(self, layer: Module) -> None:
+        if getattr(layer, "_already_called_process_weights_after_loading", False):
+            return
+
+        # Allow for accessing weights and scales in a standard way.
+        w13_weight = layer.w13_weight
+        w2_weight = layer.w2_weight
+        w13_weight_scale = getattr(layer, f"w13_{self.weight_scale_name}")
+        w2_weight_scale = getattr(layer, f"w2_{self.weight_scale_name}")
+
+        # MI300x and MI325x use FNUZ format for FP8. Convert if needed.
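+        # (e4m3fnuz uses a different exponent bias than e4m3fn, so the same
+        # bit pattern encodes a different value; the normalize helper below
+        # reinterprets the raw bits and adjusts the scales to compensate.)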
+        if current_platform.is_fp8_fnuz():
+            w13_weight, w13_weight_scale, layer.w13_input_scale = (
+                normalize_e4m3fn_to_e4m3fnuz(
+                    w13_weight, w13_weight_scale, layer.w13_input_scale
+                )
+            )
+            w2_weight, w2_weight_scale, layer.w2_input_scale = (
+                normalize_e4m3fn_to_e4m3fnuz(
+                    w2_weight, w2_weight_scale, layer.w2_input_scale
+                )
+            )
+
+        # Per-tensor kernels require a single activation scale. Use the max.
+        if self.quant_config.activation_scheme == "static":
+            assert not self.block_quant
+            assert layer.w13_input_scale is not None
+            assert layer.w2_input_scale is not None
+            if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
+                layer.w2_input_scale
+            ):
+                logger.warning_once(
+                    "Found input_scales that are not equal for "
+                    "fp8 MoE layer. Using the maximum across experts "
+                    "for each layer."
+                )
+            replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
+            replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())

+        # Per-tensor kernels require a single weight scale for w13 per expert,
+        # but on disk there is a scale for both w1 and w3. Use the max to
+        # requantize.
+        if not self.block_quant:
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        w13_weight[expert_id][start : start + shard_size, :],
+                        w13_weight_scale[expert_id][shard_id],
+                    )
+                    w13_weight[expert_id][start : start + shard_size, :], _ = (
+                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                    )
+                    start += shard_size
+            w13_weight_scale = max_w13_scales
+
+        # Shuffle weights into the runtime format.
+        self._convert_weights_to_kernel_format(
+            layer, w13_weight, w2_weight, w13_weight_scale, w2_weight_scale
+        )
+
+        # Set up the modular kernel for the TP case.
+        self._setup_kernel(layer)
+
     def maybe_make_prepare_finalize(
         self,
         routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
@@ -1453,32 +1415,26 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):

         # If checkpoint is fp16, quantize in place.
         fp8_dtype = current_platform.fp8_dtype()
-        w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
-        w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
+        w13_weight = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
+        w2_weight = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)

         for expert in range(layer.local_num_experts):
             w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                ops.scaled_fp8_quant(layer.w13_weight[expert, :, :])
             )
             w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                ops.scaled_fp8_quant(layer.w2_weight[expert, :, :])
             )
         replace_parameter(layer, "w13_weight", w13_weight)
         replace_parameter(layer, "w2_weight", w2_weight)

-        # Reshuffle weights for AITER if needed.
-        if self.fp8_backend == Fp8MoeBackend.AITER:
-            shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
-                layer.w13_weight, layer.w2_weight
-            )
-            replace_parameter(layer, "w13_weight", shuffled_w13)
-            replace_parameter(layer, "w2_weight", shuffled_w2)
+        # Shuffle weights into the runtime format.
+        self._convert_weights_to_kernel_format(
+            layer, w13_weight, w2_weight, layer.w13_weight_scale, layer.w2_weight_scale
+        )

-        # Reshuffle weights for MARLIN if needed.
-        elif self.fp8_backend == Fp8MoeBackend.MARLIN:
-            prepare_moe_fp8_layer_for_marlin(
-                layer, False, input_dtype=self.marlin_input_dtype
-            )
+        # Set up the modular kernel for the TP case.
+        self._setup_kernel(layer)


 class Fp8KVCacheMethod(BaseKVCacheMethod):
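
For readers following the per-expert requantization in the new
process_weights_after_loading(): checkpoints store one per-tensor scale for
the w1 half and one for the w3 half of each expert's fused w13 weight, while
the per-tensor MoE kernels expect a single scale per expert. Below is a
minimal standalone sketch of that step in plain PyTorch; the function name is
illustrative, and the inline dequantize/requantize stands in for vLLM's
per_tensor_dequantize and ops.scaled_fp8_quant helpers rather than copying
them.

import torch

FP8_MAX = 448.0  # finite max of torch.float8_e4m3fn


def requantize_w13(w13: torch.Tensor, scales: torch.Tensor):
    """w13: [num_experts, 2 * shard, hidden] fp8; scales: [num_experts, 2]."""
    num_experts, two_shard, _ = w13.shape
    shard = two_shard // 2
    # One shared scale per expert: the max of the w1 and w3 scales.
    max_scales = scales.max(dim=1).values
    for e in range(num_experts):
        for s in range(2):
            half = w13[e, s * shard : (s + 1) * shard, :]
            dq = half.to(torch.float32) * scales[e, s]  # dequantize
            rq = (dq / max_scales[e]).clamp(-FP8_MAX, FP8_MAX)  # requantize
            w13[e, s * shard : (s + 1) * shard, :] = rq.to(torch.float8_e4m3fn)
    return w13, max_scales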