[Quantization] FP8 Weight Reloading for Quantized RL Rollout (#28480)

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Author: Kyle Sayers
Date: 2025-12-09 16:54:32 -05:00
Committed-by: GitHub
Parent: 00e5cbb967
Commit: fccd532587
7 changed files with 206 additions and 86 deletions

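Summary: this change routes every post-load parameter mutation in the FP8 linear and MoE paths through a single replace_parameter helper (imported from vllm.model_executor.utils) and keeps block-quantized scales registered under their checkpoint name, weight_scale_inv, instead of renaming or deleting them. Both are prerequisites for reloading weights into an already-initialized engine, e.g. when an RL trainer pushes updated policy weights into a quantized rollout worker between training steps.

For orientation, a minimal sketch of the contract a replace_parameter-style helper must satisfy follows; this is an illustration of the contract, not the actual vllm.model_executor.utils implementation:

    import torch

    def replace_parameter(layer: torch.nn.Module, name: str, data: torch.Tensor) -> None:
        """Re-register layer.<name> with new data under the same name."""
        old = getattr(layer, name)
        # Re-wrap as a plain Parameter, since torch.compile() cannot trace
        # Parameter subclasses such as the loader-time ModelWeightParameter.
        new = torch.nn.Parameter(data, requires_grad=False)
        # Assumed detail: carry over loader metadata (e.g. a weight_loader
        # hook) so a second weight-loading pass routes reloaded tensors
        # exactly like the first.
        for attr, value in vars(old).items():
            setattr(new, attr, value)
        setattr(layer, name, new)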

@@ -94,7 +94,7 @@ from vllm.model_executor.parameter import (
     ModelWeightParameter,
     PerTensorScaleParameter,
 )
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.utils import replace_parameter, set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils.deep_gemm import (
@@ -548,46 +548,50 @@ class Fp8LinearMethod(LinearMethodBase):
             assert not self.act_q_static
             size_k_first = False
-            weight, weight_scale = process_fp8_weight_block_strategy(
+            weight, weight_scale_inv = process_fp8_weight_block_strategy(
                 layer.weight, layer.weight_scale_inv
             )
-            # Delete the weight_scale_inv parameter to avoid confusion
-            # with the weight_scale parameter
-            del layer.weight_scale_inv
+            # Update layer with new values
+            replace_parameter(layer, "weight", weight.data)
+            replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
 
-        # If checkpoint not serialized fp8, quantize the weights.
-        elif not self.quant_config.is_checkpoint_fp8_serialized:
-            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
-            weight = qweight.t()
-        # If checkpoint is fp8 per-tensor, handle that there are N scales for N
-        # shards in a fused module
         else:
-            weight = layer.weight
-            weight_scale = layer.weight_scale
+            if not self.quant_config.is_checkpoint_fp8_serialized:
+                qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
+                weight = qweight.t()
+            # If checkpoint is fp8 per-tensor, handle that there are N scales for N
+            # shards in a fused module
+            else:
+                weight = layer.weight
+                weight_scale = layer.weight_scale
 
-        # If using w8a8, torch._scaled_mm needs per tensor, so
-        # requantize the logical shards as a single weight.
-        if not self.use_marlin:
-            weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
-                weight,
-                weight_scale,
-                layer.logical_widths,
-                getattr(layer, "input_scale", None),
-            )
-            if self.act_q_static:
-                assert input_scale is not None
-                input_scale = input_scale.max()
-            weight = weight.t()
+            # If using w8a8, torch._scaled_mm needs per tensor, so
+            # requantize the logical shards as a single weight.
+            if not self.use_marlin:
+                weight, weight_scale, input_scale = (
+                    process_fp8_weight_tensor_strategy(
+                        weight,
+                        weight_scale,
+                        layer.logical_widths,
+                        getattr(layer, "input_scale", None),
+                    )
+                )
+                if self.act_q_static:
+                    assert input_scale is not None
+                    input_scale = input_scale.max()
+                weight = weight.t()
 
-        # Update layer with new values.
-        layer.weight = Parameter(weight.data, requires_grad=False)
-        layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
-        layer.input_scale = (
-            Parameter(input_scale, requires_grad=False)
-            if input_scale is not None
-            else None
-        )
+            # Update layer with new values.
+            replace_parameter(layer, "weight", weight.data)
+            replace_parameter(layer, "weight_scale", weight_scale.data)
+            if input_scale is not None:
+                replace_parameter(layer, "input_scale", input_scale)
+            else:
+                layer.input_scale = None
 
         if self.use_marlin:
             prepare_fp8_layer_for_marlin(
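The behavioral fix in this hunk is easiest to see in isolation: the old code ran del layer.weight_scale_inv after folding the scales in, so a later reload that looks parameters up by name has nothing to bind the checkpoint's weight_scale_inv tensors to. A self-contained repro of that failure mode (hypothetical toy module, not vLLM code):

    import torch

    class TinyLayer(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = torch.nn.Parameter(torch.empty(4, 4), requires_grad=False)
            self.weight_scale_inv = torch.nn.Parameter(torch.empty(1), requires_grad=False)

    layer = TinyLayer()
    # Old behavior: the scale parameter is deleted after processing, so a
    # reload that resolves parameters by name can no longer find it.
    del layer.weight_scale_inv
    assert "weight_scale_inv" not in dict(layer.named_parameters())

replace_parameter keeps the name registered (now holding the processed data), so the same lookup succeeds on every subsequent reload.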
@@ -614,7 +618,7 @@ class Fp8LinearMethod(LinearMethodBase):
             return self.w8a8_block_fp8_linear.apply(
                 input=x,
                 weight=layer.weight,
-                weight_scale=layer.weight_scale,
+                weight_scale=layer.weight_scale_inv,
                 input_scale=layer.input_scale,
                 bias=bias,
             )
@@ -643,10 +647,15 @@ class Fp8LinearMethod(LinearMethodBase):
             return torch.nn.functional.linear(x, weight_bf16.t(), bias)
 
         if self.use_marlin:
+            if self.block_quant:
+                weight_scale = layer.weight_scale_inv
+            else:
+                weight_scale = layer.weight_scale
+
             return apply_fp8_marlin_linear(
                 input=x,
                 weight=layer.weight,
-                weight_scale=layer.weight_scale,
+                weight_scale=weight_scale,
                 workspace=layer.workspace,
                 size_n=layer.output_size_per_partition,
                 size_k=layer.input_size_per_partition,
@@ -660,7 +669,7 @@ class Fp8LinearMethod(LinearMethodBase):
             return self.w8a8_block_fp8_linear.apply(
                 input=x,
                 weight=layer.weight,
-                weight_scale=layer.weight_scale,
+                weight_scale=layer.weight_scale_inv,
                 input_scale=layer.input_scale,
                 bias=bias,
             )
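Because block-quantized scales now stay under weight_scale_inv, every consumer has to read the right attribute; the hunks above update the block-FP8 apply path and the Marlin fallback accordingly. The rule they enforce reduces to one line (sketch):

    # Naming rule enforced by these hunks:
    scale = layer.weight_scale_inv if self.block_quant else layer.weight_scale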
@@ -937,22 +946,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 w2_weight_scale_inv = layer.w2_weight_scale_inv
 
-            # torch.compile() cannot use Parameter subclasses.
-            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
-            layer.w13_weight_scale_inv = Parameter(
-                w13_weight_scale_inv, requires_grad=False
-            )
-            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
-            layer.w2_weight_scale_inv = Parameter(
-                w2_weight_scale_inv, requires_grad=False
-            )
+            replace_parameter(layer, "w13_weight", w13_weight)
+            replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
+            replace_parameter(layer, "w2_weight", w2_weight)
+            replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
 
             if self.rocm_aiter_moe_enabled:
                 # reshaping weights is required for aiter moe kernel.
                 shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                     layer.w13_weight.data, layer.w2_weight.data
                 )
-                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
-                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
+                replace_parameter(layer, "w13_weight", shuffled_w13)
+                replace_parameter(layer, "w2_weight", shuffled_w2)
 
             # DeepGemm scales need to be transposed and aligned. We try to do
             # it ahead of time for performance reasons.
@@ -990,13 +995,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 # Re-initialize w13_scale because we directly quantize
                 # merged w13 weights and generate a single scaling factor.
-                layer.w13_weight_scale = torch.nn.Parameter(
+                replace_parameter(
+                    layer,
+                    "w13_weight_scale",
                     torch.ones(
                         layer.local_num_experts,
                         dtype=torch.float32,
                         device=w13_weight.device,
                     ),
-                    requires_grad=False,
                 )
                 for expert in range(layer.local_num_experts):
                     w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
@@ -1005,16 +1011,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                         ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                     )
-                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                replace_parameter(layer, "w13_weight", w13_weight)
+                replace_parameter(layer, "w2_weight", w2_weight)
 
                 if self.rocm_aiter_moe_enabled:
                     # reshaping weights is required for aiter moe kernel.
                     shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                         layer.w13_weight, layer.w2_weight
                     )
-                    layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
-                    layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
+                    replace_parameter(layer, "w13_weight", shuffled_w13)
+                    replace_parameter(layer, "w2_weight", shuffled_w2)
 
             # If checkpoint is fp8, we need to handle that the
             # MoE kernels require single activation scale and single weight
             # scale for w13 per expert.
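The loop above quantizes each expert separately, producing one per-tensor scale per expert. As a reference for what ops.scaled_fp8_quant computes when called with scale=None (dynamic per-tensor quantization), a plain-PyTorch sketch, not the actual CUDA kernel:

    import torch

    def scaled_fp8_quant_ref(t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Dynamic per-tensor scale: map the max magnitude onto the FP8 E4M3 max,
        # so that dequantization is q.to(float32) * scale.
        finfo = torch.finfo(torch.float8_e4m3fn)
        scale = t.abs().max().clamp(min=1e-12) / finfo.max
        q = (t / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
        return q, scale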
@@ -1035,12 +1042,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                         "fp8 MoE layer. Using the maximum across experts "
                         "for each layer."
                     )
-                layer.w13_input_scale = torch.nn.Parameter(
-                    layer.w13_input_scale.max(), requires_grad=False
-                )
-                layer.w2_input_scale = torch.nn.Parameter(
-                    layer.w2_input_scale.max(), requires_grad=False
-                )
+                replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
+                replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())
             if current_platform.is_fp8_fnuz():
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = (
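The max() reduction above is the conservative choice: a single shared activation scale must cover the widest per-expert dynamic range, otherwise some expert's activations would clip. In miniature:

    import torch

    per_expert_input_scales = torch.tensor([0.7, 1.9, 1.2])
    shared_input_scale = per_expert_input_scales.max()  # 1.9 covers every expert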
@@ -1054,22 +1057,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     )
                 )
                 # Reset the parameter
-                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-                layer.w13_weight_scale = torch.nn.Parameter(
-                    w13_weight_scale, requires_grad=False
-                )
+                replace_parameter(layer, "w13_weight", w13_weight)
+                replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
                 if w13_input_scale is not None:
-                    layer.w13_input_scale = torch.nn.Parameter(
-                        w13_input_scale, requires_grad=False
-                    )
-                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-                layer.w2_weight_scale = torch.nn.Parameter(
-                    w2_weight_scale, requires_grad=False
-                )
+                    replace_parameter(layer, "w13_input_scale", w13_input_scale)
+                replace_parameter(layer, "w2_weight", w2_weight)
+                replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
                 if w2_input_scale is not None:
-                    layer.w2_input_scale = torch.nn.Parameter(
-                        w2_input_scale, requires_grad=False
-                    )
+                    replace_parameter(layer, "w2_input_scale", w2_input_scale)
 
             # Fp8 moe kernel needs single weight scale for w13 per expert.
             # We take the max then dequant and requant each expert.
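The comment above describes the w13 shard-merging trick: take the max scale across an expert's shards, dequantize each fp8 shard with its own scale, and requantize under the shared maximum. Reference semantics in plain PyTorch (a sketch, not the vLLM kernel; the function name is illustrative):

    import torch

    def requant_to_shared_scale(
        shards: list[torch.Tensor],  # fp8 shards of one expert's w13 weight
        scales: list[torch.Tensor],  # one per-tensor scale per shard
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
        max_scale = torch.stack(scales).max()
        requantized = []
        for q, s in zip(shards, scales):
            w = q.to(torch.float32) * s  # dequantize with the shard's own scale
            # Safe: s <= max_scale, so w / max_scale never exceeds the fp8 range.
            requantized.append((w / max_scale).to(torch.float8_e4m3fn))
        return requantized, max_scale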
@@ -1093,12 +1088,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     layer.w13_weight, layer.w2_weight
                 )
-                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
-                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
+                replace_parameter(layer, "w13_weight", shuffled_w13)
+                replace_parameter(layer, "w2_weight", shuffled_w2)
 
-            layer.w13_weight_scale = torch.nn.Parameter(
-                max_w13_scales, requires_grad=False
-            )
+            replace_parameter(layer, "w13_weight_scale", max_w13_scales)
 
             if self.flashinfer_moe_backend is not None:
                 # NOTE: weights have to be swapped since the activation is