[MoE Refactor][10/N] Cleanup Fp8 Process Weights After Loading (#31169)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
Robert Shaw
2025-12-27 15:22:48 -05:00
committed by GitHub
parent 2f12cd32c0
commit 727c41f3fd


@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Optional
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.utils._python_dispatch import TorchDispatchMode
import vllm.envs as envs
@@ -728,6 +727,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.quant_config = quant_config
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant: bool = self.weight_block_size is not None
self.weight_scale_name = (
"weight_scale_inv" if self.block_quant else "weight_scale"
)
self.fp8_backend = get_fp8_moe_backend(
self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled
)
@@ -832,38 +834,28 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# WEIGHT_SCALES
if not self.block_quant:
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined into a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# For per-tensor quant, there is one scale per expert for each weight (w1, w3, w2).
w13_scale_data = torch.ones(num_experts, 2, dtype=torch.float32)
w2_scale_data = torch.ones(num_experts, dtype=torch.float32)
else:
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
# For block quant, the scales are per block (typically 128x128).
w13_scale_data = torch.ones(
num_experts,
2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
(hidden_size + block_n - 1) // block_n,
(intermediate_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
w2_scale_data = torch.ones(
num_experts,
(hidden_size + block_n - 1) // block_n,
(intermediate_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32,
)
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
assert self.quant_config.activation_scheme == "dynamic"
w13_weight_scale = torch.nn.Parameter(w13_scale_data, requires_grad=False)
w2_weight_scale = torch.nn.Parameter(w2_scale_data, requires_grad=False)
# Note: name is weight_scale for tensor, weight_scale_inv for block.
layer.register_parameter(f"w13_{self.weight_scale_name}", w13_weight_scale)
layer.register_parameter(f"w2_{self.weight_scale_name}", w2_weight_scale)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
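
The ceil-divisions above allocate one fp32 scale per (block_n x block_k) tile of each expert weight, while the per-tensor path keeps one scale per expert per weight. A standalone sketch (outside this diff) of the block-scale shape arithmetic, assuming the typical 128x128 block size; the helper name and example dimensions are illustrative only:

def fp8_block_scale_shapes(
    num_experts: int,
    intermediate_size_per_partition: int,
    hidden_size: int,
    block_n: int = 128,
    block_k: int = 128,
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    # One fp32 scale per (block_n x block_k) tile of each expert weight.
    ceil_div = lambda a, b: (a + b - 1) // b
    # w13 is [2 * intermediate, hidden] per expert; w2 is [hidden, intermediate].
    w13_scale_shape = (
        num_experts,
        2 * ceil_div(intermediate_size_per_partition, block_n),
        ceil_div(hidden_size, block_k),
    )
    w2_scale_shape = (
        num_experts,
        ceil_div(hidden_size, block_n),
        ceil_div(intermediate_size_per_partition, block_k),
    )
    return w13_scale_shape, w2_scale_shape

# Example: 8 experts, intermediate 1408, hidden 7168 ->
#   w13 scales (8, 22, 56), w2 scales (8, 56, 11)
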
@@ -877,6 +869,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# INPUT_SCALES
if self.quant_config.activation_scheme == "static":
assert not self.block_quant
w13_input_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
@@ -893,158 +886,60 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
# TODO (rob): refactor block quant into separate class.
if self.block_quant:
assert self.quant_config.activation_scheme == "dynamic"
if current_platform.is_fp8_fnuz():
w13_weight, w13_weight_scale_inv, w13_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
layer.w13_weight,
layer.w13_weight_scale_inv,
layer.w13_input_scale,
)
)
w2_weight, w2_weight_scale_inv, w2_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale
)
)
elif self.flashinfer_moe_backend is not None:
# NOTE: weights have to be swapped since the activation is
# applied on a different half for flashinfer vs vllm
w13_weight = swap_w13_to_w31(layer.w13_weight.data)
w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
w2_weight = layer.w2_weight.data
w2_weight_scale_inv = layer.w2_weight_scale_inv.data
def _convert_weights_to_kernel_format(
self,
layer: Module,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
w13_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
) -> None:
if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
assert self.block_quant
w13_weight, w13_weight_scale = deepgemm_post_process_fp8_weight_block(
wq=w13_weight,
ws=w13_weight_scale,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(),
)
w2_weight, w2_weight_scale = deepgemm_post_process_fp8_weight_block(
wq=w2_weight,
ws=w2_weight_scale,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(),
)
elif self.fp8_backend == Fp8MoeBackend.AITER:
w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
w13_weight, w2_weight
)
elif self.fp8_backend in [
Fp8MoeBackend.FLASHINFER_CUTLASS,
Fp8MoeBackend.FLASHINFER_TRTLLM,
]:
w13_weight = swap_w13_to_w31(w13_weight)
if self.block_quant:
w13_weight_scale = swap_w13_to_w31(w13_weight_scale)
else:
w13_weight = layer.w13_weight.data
w13_weight_scale_inv = layer.w13_weight_scale_inv.data
w2_weight = layer.w2_weight
w2_weight_scale_inv = layer.w2_weight_scale_inv
# torch.compile() cannot use Parameter subclasses.
replace_parameter(layer, "w13_weight", w13_weight)
replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
replace_parameter(layer, "w2_weight", w2_weight)
replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
if self.fp8_backend == Fp8MoeBackend.AITER:
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
replace_parameter(layer, "w13_weight", shuffled_w13)
replace_parameter(layer, "w2_weight", shuffled_w2)
# DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons.
if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
dg_w13_weight, dg_w13_weight_scale_inv = (
deepgemm_post_process_fp8_weight_block(
wq=layer.w13_weight.data,
ws=layer.w13_weight_scale_inv.data,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(),
)
)
dg_w2_weight, dg_w2_weight_scale_inv = (
deepgemm_post_process_fp8_weight_block(
wq=layer.w2_weight.data,
ws=layer.w2_weight_scale_inv.data,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(),
)
)
layer.w13_weight = Parameter(dg_w13_weight, requires_grad=False)
layer.w13_weight_scale_inv = Parameter(
dg_w13_weight_scale_inv, requires_grad=False
)
layer.w2_weight = Parameter(dg_w2_weight, requires_grad=False)
layer.w2_weight_scale_inv = Parameter(
dg_w2_weight_scale_inv, requires_grad=False
)
else:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if self.quant_config.activation_scheme == "static":
if layer.w13_input_scale is None or layer.w2_input_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
layer.w2_input_scale
):
logger.warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer."
)
replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())
if current_platform.is_fp8_fnuz():
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
)
)
w2_weight, w2_weight_scale, w2_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
)
)
# Reset the parameter
replace_parameter(layer, "w13_weight", w13_weight)
replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
if w13_input_scale is not None:
replace_parameter(layer, "w13_input_scale", w13_input_scale)
replace_parameter(layer, "w2_weight", w2_weight)
replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
if w2_input_scale is not None:
replace_parameter(layer, "w2_input_scale", w2_input_scale)
# Fp8 moe kernels need a single weight scale for w13 per expert.
# We take the max, then dequantize and requantize each expert.
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.local_num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start : start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id],
)
layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
)
start += shard_size
if self.fp8_backend == Fp8MoeBackend.AITER:
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight
)
replace_parameter(layer, "w13_weight", shuffled_w13)
replace_parameter(layer, "w2_weight", shuffled_w2)
replace_parameter(layer, "w13_weight_scale", max_w13_scales)
if self.flashinfer_moe_backend is not None:
# NOTE: weights have to be swapped since the activation is
# applied on a different half for flashinfer vs vllm
assert not self.block_quant
# TODO(rob): this function renames the scaling factors in the
# Module. It is a hack that we should clean up.
register_moe_scaling_factors(layer)
w13_weight = swap_w13_to_w31(layer.w13_weight.data)
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
layer.w13_weight.data = w13_weight.data
elif self.fp8_backend == Fp8MoeBackend.AITER:
w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
w13_weight, w2_weight
)
# Replace parameters with updated versions. Note that this helper
# function ensures the replacement is compatible with RL weight reloads.
replace_parameter(layer, "w13_weight", w13_weight)
replace_parameter(layer, "w2_weight", w2_weight)
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_weight_scale)
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_weight_scale)
# TODO(rob): we do this after replace_parameter() because
# prepare_moe_fp8_layer_for_marlin uses the layer's params
# directly. We will refactor this in a follow-up PR.
if self.fp8_backend == Fp8MoeBackend.MARLIN:
prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype
@@ -1053,6 +948,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w13_input_scale
del layer.w2_input_scale
def _setup_kernel(self, layer: Module) -> None:
"""Setup Modular Kernel for TP Case"""
# NOTE(rob): this is a WIP refactor. We are first migrating
# all of the kernels in the TP case to use mk. Once this is
# done, we will initialize the TP case and DP/EP case
@@ -1134,6 +1031,71 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
self.use_inplace = True
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
# Allow for accessing weights and scales in standard way.
w13_weight = layer.w13_weight
w2_weight = layer.w2_weight
w13_weight_scale = getattr(layer, f"w13_{self.weight_scale_name}")
w2_weight_scale = getattr(layer, f"w2_{self.weight_scale_name}")
# MI300x and MI325x use the FNUZ format for FP8. Convert if needed.
if current_platform.is_fp8_fnuz():
w13_weight, w13_weight_scale, layer.w13_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
w13_weight, w13_weight_scale, layer.w13_input_scale
)
)
w2_weight, w2_weight_scale, layer.w2_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
w2_weight, w2_weight_scale, layer.w2_input_scale
)
)
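
normalize_e4m3fn_to_e4m3fnuz accounts for the FNUZ fp8 format on MI300-class GPUs using an exponent bias of 8 instead of 7, so a given bit pattern encodes half the value it would in e4m3fn, and the byte 0x80 (negative zero in e4m3fn) is the NaN encoding in e4m3fnuz. A rough standalone sketch (outside this diff) of what such a conversion involves, assuming torch exposes float8_e4m3fnuz; the real helper may handle more cases:

import torch

def fn_to_fnuz_sketch(weight_fp8: torch.Tensor, weight_scale: torch.Tensor):
    # Note: modifies weight_fp8's storage in place via the int8 view.
    as_int8 = weight_fp8.view(torch.int8)
    as_int8[as_int8 == -128] = 0                        # 0x80 would be NaN in fnuz
    weight_fnuz = as_int8.view(torch.float8_e4m3fnuz)   # reinterpret bits, no copy
    return weight_fnuz, weight_scale * 2.0              # bias differs by 1 -> double scale
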
# Per tensor kernels require a single activation scale. Use the max.
if self.quant_config.activation_scheme == "static":
assert not self.block_quant
assert layer.w13_input_scale is not None
assert layer.w2_input_scale is not None
if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
layer.w2_input_scale
):
logger.warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer."
)
replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())
# Per tensor kernels require a single weight scale for w13 per expert, but
# on disk there is a scale for each of w1 and w3. Use the max to requantize.
if not self.block_quant:
shard_size = layer.intermediate_size_per_partition
max_w13_scales = w13_weight_scale.max(dim=1).values
for expert_id in range(layer.local_num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
w13_weight[expert_id][start : start + shard_size, :],
w13_weight_scale[expert_id][shard_id],
)
w13_weight[expert_id][start : start + shard_size, :], _ = (
ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
)
start += shard_size
w13_weight_scale = max_w13_scales
# Shuffle weights into the runtime format.
self._convert_weights_to_kernel_format(
layer, w13_weight, w2_weight, w13_weight_scale, w2_weight_scale
)
# Setup modular kernel for TP case.
self._setup_kernel(layer)
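
The per-expert loop above dequantizes the w1 and w3 shards with their original scales and requantizes both under the shared max scale, so the kernel only needs one weight scale per expert for w13. A standalone sketch (outside this diff) of that dequant/requant round trip, using plain torch in place of vLLM's per_tensor_dequantize and ops.scaled_fp8_quant:

import torch

def requantize_w13_to_shared_scale(w13_q: torch.Tensor, w13_scale: torch.Tensor):
    # w13_q: fp8 weights, [num_experts, 2 * shard, hidden]
    # w13_scale: fp32 scales, [num_experts, 2] (one for w1, one for w3)
    fp8 = w13_q.dtype
    fp8_max = torch.finfo(fp8).max
    num_experts, two_shard, _ = w13_q.shape
    shard = two_shard // 2
    max_scale = w13_scale.max(dim=1).values
    for e in range(num_experts):
        for s in range(2):
            blk = w13_q[e, s * shard : (s + 1) * shard]
            deq = blk.to(torch.float32) * w13_scale[e, s]          # dequantize
            req = (deq / max_scale[e]).clamp(-fp8_max, fp8_max)    # requantize
            w13_q[e, s * shard : (s + 1) * shard] = req.to(fp8)
    return w13_q, max_scale
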
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
@@ -1453,32 +1415,26 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
# If checkpoint is fp16, quantize in place.
fp8_dtype = current_platform.fp8_dtype()
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
w13_weight = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
for expert in range(layer.local_num_experts):
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
ops.scaled_fp8_quant(layer.w13_weight[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
ops.scaled_fp8_quant(layer.w2_weight[expert, :, :])
)
replace_parameter(layer, "w13_weight", w13_weight)
replace_parameter(layer, "w2_weight", w2_weight)
# Reshuffle weights for AITER if needed.
if self.fp8_backend == Fp8MoeBackend.AITER:
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight
)
replace_parameter(layer, "w13_weight", shuffled_w13)
replace_parameter(layer, "w2_weight", shuffled_w2)
# Shuffle weights into the runtime format.
self._convert_weights_to_kernel_format(
layer, w13_weight, w2_weight, layer.w13_weight_scale, layer.w2_weight_scale
)
# Reshuffle weights for MARLIN if needed.
elif self.fp8_backend == Fp8MoeBackend.MARLIN:
prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype
)
# Setup modular kernel for TP case.
self._setup_kernel(layer)
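
ops.scaled_fp8_quant called without an explicit scale computes a dynamic per-tensor scale and returns both the quantized tensor and that scale, which is what fills w13_weight_scale and w2_weight_scale per expert above. A rough standalone sketch (outside this diff) of dynamic per-tensor fp8 quantization for one expert's weight, using plain torch:

import torch

def dynamic_per_tensor_fp8_quant(x: torch.Tensor,
                                 fp8_dtype: torch.dtype = torch.float8_e4m3fn):
    # Choose the scale so the largest magnitude maps to the fp8 max value.
    fp8_max = torch.finfo(fp8_dtype).max
    scale = (x.abs().max().float() / fp8_max).clamp(min=1e-12)
    x_q = (x.float() / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return x_q, scale
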
class Fp8KVCacheMethod(BaseKVCacheMethod):