[Quantization] FP8 Weight Reloading for Quantized RL Rollout (#28480)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
@@ -94,7 +94,7 @@ from vllm.model_executor.parameter import (
     ModelWeightParameter,
     PerTensorScaleParameter,
 )
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.utils import replace_parameter, set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils.deep_gemm import (
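Note: the only functional change in this hunk is importing replace_parameter from vllm.model_executor.utils alongside set_weight_attrs. Its implementation is not shown in this diff; as a rough mental model only (an assumption, not vLLM's actual code), a helper of this shape would re-register a plain torch.nn.Parameter under the same name while carrying over loader metadata, which is what lets an RL rollout reload freshly quantized weights into an already-processed layer:

    import torch


    def replace_parameter_sketch(
        module: torch.nn.Module, name: str, new_data: torch.Tensor
    ) -> None:
        # Hypothetical stand-in for vllm.model_executor.utils.replace_parameter.
        old = getattr(module, name)
        new = torch.nn.Parameter(new_data, requires_grad=False)
        # Carry over loading metadata (e.g. weight_loader) if present, so a
        # later weight reload can still find and use it.
        for attr in ("weight_loader", "input_dim", "output_dim"):
            if hasattr(old, attr):
                setattr(new, attr, getattr(old, attr))
        # Re-register under the same name rather than assigning a bare tensor,
        # keeping the module's parameter registry consistent.
        module.register_parameter(name, new)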
@@ -548,46 +548,50 @@ class Fp8LinearMethod(LinearMethodBase):
             assert not self.act_q_static
             size_k_first = False
 
-            weight, weight_scale = process_fp8_weight_block_strategy(
+            weight, weight_scale_inv = process_fp8_weight_block_strategy(
                 layer.weight, layer.weight_scale_inv
             )
-            # Delete the weight_scale_inv parameter to avoid confusion
-            # with the weight_scale parameter
-            del layer.weight_scale_inv
 
+            # Update layer with new values
+            replace_parameter(layer, "weight", weight.data)
+            replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
+
         # If checkpoint not serialized fp8, quantize the weights.
-        elif not self.quant_config.is_checkpoint_fp8_serialized:
-            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
-            weight = qweight.t()
-
-        # If checkpoint is fp8 per-tensor, handle that there are N scales for N
-        # shards in a fused module
         else:
-            weight = layer.weight
-            weight_scale = layer.weight_scale
+            if not self.quant_config.is_checkpoint_fp8_serialized:
+                qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
+                weight = qweight.t()
 
-            # If using w8a8, torch._scaled_mm needs per tensor, so
-            # requantize the logical shards as a single weight.
-            if not self.use_marlin:
-                weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
-                    weight,
-                    weight_scale,
-                    layer.logical_widths,
-                    getattr(layer, "input_scale", None),
-                )
-                if self.act_q_static:
-                    assert input_scale is not None
-                    input_scale = input_scale.max()
-                weight = weight.t()
+            # If checkpoint is fp8 per-tensor, handle that there are N scales for N
+            # shards in a fused module
+            else:
+                weight = layer.weight
+                weight_scale = layer.weight_scale
 
-        # Update layer with new values.
-        layer.weight = Parameter(weight.data, requires_grad=False)
-        layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
-        layer.input_scale = (
-            Parameter(input_scale, requires_grad=False)
-            if input_scale is not None
-            else None
-        )
+            # If using w8a8, torch._scaled_mm needs per tensor, so
+            # requantize the logical shards as a single weight.
+            if not self.use_marlin:
+                weight, weight_scale, input_scale = (
+                    process_fp8_weight_tensor_strategy(
+                        weight,
+                        weight_scale,
+                        layer.logical_widths,
+                        getattr(layer, "input_scale", None),
+                    )
+                )
+                if self.act_q_static:
+                    assert input_scale is not None
+                    input_scale = input_scale.max()
+                weight = weight.t()
+
+            # Update layer with new values.
+            replace_parameter(layer, "weight", weight.data)
+            replace_parameter(layer, "weight_scale", weight_scale.data)
+
+            if input_scale is not None:
+                replace_parameter(layer, "input_scale", input_scale)
+            else:
+                layer.input_scale = None
 
         if self.use_marlin:
             prepare_fp8_layer_for_marlin(
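Note: the rewritten branch above registers the processed tensors through replace_parameter instead of constructing fresh Parameter objects, so the layer keeps its weight / weight_scale (or weight_scale_inv) entries addressable by name after startup. A hedged sketch of the rollout-style reload this is meant to support; quantize_fp8_per_tensor and reload_fp8_linear are illustrative helpers, not vLLM APIs, and the transpose mirrors weight = qweight.t() in the hunk:

    import torch


    def quantize_fp8_per_tensor(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # One scale for the whole tensor, matching the per-tensor strategy above.
        finfo = torch.finfo(torch.float8_e4m3fn)
        scale = torch.clamp(w.abs().max().float(), min=1e-12) / finfo.max
        q = (w / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
        return q, scale


    def reload_fp8_linear(layer: torch.nn.Module, new_weight: torch.Tensor) -> None:
        # Push updated trainer weights into a live fp8 layer by parameter name.
        q, scale = quantize_fp8_per_tensor(new_weight)
        params = dict(layer.named_parameters())
        # Works because the post-processed parameters kept their original names.
        params["weight"].data.copy_(q.t())
        params["weight_scale"].data.copy_(scale)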
@@ -614,7 +618,7 @@ class Fp8LinearMethod(LinearMethodBase):
             return self.w8a8_block_fp8_linear.apply(
                 input=x,
                 weight=layer.weight,
-                weight_scale=layer.weight_scale,
+                weight_scale=layer.weight_scale_inv,
                 input_scale=layer.input_scale,
                 bias=bias,
             )
@@ -643,10 +647,15 @@ class Fp8LinearMethod(LinearMethodBase):
             return torch.nn.functional.linear(x, weight_bf16.t(), bias)
 
         if self.use_marlin:
+            if self.block_quant:
+                weight_scale = layer.weight_scale_inv
+            else:
+                weight_scale = layer.weight_scale
+
             return apply_fp8_marlin_linear(
                 input=x,
                 weight=layer.weight,
-                weight_scale=layer.weight_scale,
+                weight_scale=weight_scale,
                 workspace=layer.workspace,
                 size_n=layer.output_size_per_partition,
                 size_k=layer.input_size_per_partition,
@@ -660,7 +669,7 @@ class Fp8LinearMethod(LinearMethodBase):
             return self.w8a8_block_fp8_linear.apply(
                 input=x,
                 weight=layer.weight,
-                weight_scale=layer.weight_scale,
+                weight_scale=layer.weight_scale_inv,
                 input_scale=layer.input_scale,
                 bias=bias,
             )
@@ -937,22 +946,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             w2_weight_scale_inv = layer.w2_weight_scale_inv
 
             # torch.compile() cannot use Parameter subclasses.
-            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
-            layer.w13_weight_scale_inv = Parameter(
-                w13_weight_scale_inv, requires_grad=False
-            )
-            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
-            layer.w2_weight_scale_inv = Parameter(
-                w2_weight_scale_inv, requires_grad=False
-            )
+            replace_parameter(layer, "w13_weight", w13_weight)
+            replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
+            replace_parameter(layer, "w2_weight", w2_weight)
+            replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
             if self.rocm_aiter_moe_enabled:
                 # reshaping weights is required for aiter moe kernel.
                 shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                     layer.w13_weight.data, layer.w2_weight.data
                 )
 
-                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
-                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
+                replace_parameter(layer, "w13_weight", shuffled_w13)
+                replace_parameter(layer, "w2_weight", shuffled_w2)
 
             # DeepGemm scales need to be transposed and aligned. We try to do
             # it ahead of time for performance reasons.
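Note: the "torch.compile() cannot use Parameter subclasses." context line above appears to refer to vLLM's loading-time parameter classes (such as ModelWeightParameter imported earlier): before compilation the weights are re-registered as plain parameters that share the same storage. A toy illustration of that distinction, with a hypothetical stand-in subclass:

    import torch


    class LoadingTimeParameter(torch.nn.Parameter):
        # Stand-in for a metadata-carrying subclass used only during loading.
        pass


    w = LoadingTimeParameter(torch.empty(8, 8))
    plain = torch.nn.Parameter(w.data, requires_grad=False)

    assert type(plain) is torch.nn.Parameter  # no subclass left for compile to trace
    assert plain.data_ptr() == w.data_ptr()   # same underlying storage, no copy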
@@ -990,13 +995,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
             # Re-initialize w13_scale because we directly quantize
             # merged w13 weights and generate a single scaling factor.
-            layer.w13_weight_scale = torch.nn.Parameter(
+            replace_parameter(
+                layer,
+                "w13_weight_scale",
                 torch.ones(
                     layer.local_num_experts,
                     dtype=torch.float32,
                     device=w13_weight.device,
                 ),
-                requires_grad=False,
             )
             for expert in range(layer.local_num_experts):
                 w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
@@ -1005,16 +1011,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                     ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                 )
-            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+            replace_parameter(layer, "w13_weight", w13_weight)
+            replace_parameter(layer, "w2_weight", w2_weight)
+
             if self.rocm_aiter_moe_enabled:
                 # reshaping weights is required for aiter moe kernel.
                 shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                     layer.w13_weight, layer.w2_weight
                 )
 
-                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
-                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
+                replace_parameter(layer, "w13_weight", shuffled_w13)
+                replace_parameter(layer, "w2_weight", shuffled_w2)
         # If checkpoint is fp8, we need to handle that the
         # MoE kernels require single activation scale and single weight
         # scale for w13 per expert.
@@ -1035,12 +1042,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     "fp8 MoE layer. Using the maximum across experts "
                    "for each layer."
                 )
-                layer.w13_input_scale = torch.nn.Parameter(
-                    layer.w13_input_scale.max(), requires_grad=False
-                )
-                layer.w2_input_scale = torch.nn.Parameter(
-                    layer.w2_input_scale.max(), requires_grad=False
-                )
+                replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
+                replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())
             if current_platform.is_fp8_fnuz():
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = (
@@ -1054,22 +1057,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     )
                 )
             # Reset the parameter
-            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-            layer.w13_weight_scale = torch.nn.Parameter(
-                w13_weight_scale, requires_grad=False
-            )
+            replace_parameter(layer, "w13_weight", w13_weight)
+            replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
             if w13_input_scale is not None:
-                layer.w13_input_scale = torch.nn.Parameter(
-                    w13_input_scale, requires_grad=False
-                )
-            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-            layer.w2_weight_scale = torch.nn.Parameter(
-                w2_weight_scale, requires_grad=False
-            )
+                replace_parameter(layer, "w13_input_scale", w13_input_scale)
+            replace_parameter(layer, "w2_weight", w2_weight)
+            replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
             if w2_input_scale is not None:
-                layer.w2_input_scale = torch.nn.Parameter(
-                    w2_input_scale, requires_grad=False
-                )
+                replace_parameter(layer, "w2_input_scale", w2_input_scale)
 
             # Fp8 moe kernel needs single weight scale for w13 per expert.
             # We take the max then dequant and requant each expert.
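Note: the trailing context comments state the constraint: the fused fp8 MoE kernel wants a single w13 weight scale per expert, so each logical shard is dequantized with its own scale and requantized against the expert's maximum. A hedged sketch of that rescaling, assuming two equal-width w1/w3 shards per expert (the layout and names are illustrative, not vLLM's):

    import torch


    def requant_w13_to_max_scale(
        w13: torch.Tensor,        # [num_experts, 2 * shard_dim, hidden], fp8
        w13_scale: torch.Tensor,  # [num_experts, 2], one scale per logical shard
    ) -> tuple[torch.Tensor, torch.Tensor]:
        finfo = torch.finfo(torch.float8_e4m3fn)
        max_scales = w13_scale.max(dim=1).values  # one scale per expert
        shard_dim = w13.shape[1] // 2
        for e in range(w13.shape[0]):
            for s in range(2):
                rows = slice(s * shard_dim, (s + 1) * shard_dim)
                # Dequantize with the shard's old scale, requantize with the max.
                dq = w13[e, rows].to(torch.float32) * w13_scale[e, s]
                w13[e, rows] = (
                    (dq / max_scales[e]).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
                )
        return w13, max_scales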
@@ -1093,12 +1088,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     layer.w13_weight, layer.w2_weight
                 )
 
-                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
-                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
+                replace_parameter(layer, "w13_weight", shuffled_w13)
+                replace_parameter(layer, "w2_weight", shuffled_w2)
 
-            layer.w13_weight_scale = torch.nn.Parameter(
-                max_w13_scales, requires_grad=False
-            )
+            replace_parameter(layer, "w13_weight_scale", max_w13_scales)
 
         if self.flashinfer_moe_backend is not None:
             # NOTE: weights have to be swapped since the activation is