fp8 online quant: split out Fp8OnlineLinearMethod (#32189)

Vasiliy Kuznetsov
2026-01-20 18:13:22 -05:00
committed by GitHub
parent 22375f8d13
commit d2389c1262
2 changed files with 143 additions and 110 deletions


@@ -133,7 +133,7 @@ def test_kv_cache_model_load_and_run(
 @pytest.mark.parametrize(
     "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
 )
-def test_load_fp16_model(
+def test_online_quantization(
     vllm_runner,
     kv_cache_dtype: str,
     force_marlin: bool,
@@ -191,6 +191,9 @@ def test_load_fp16_model(
     llm.apply_model(check_model)
     outputs = llm.generate_greedy(["Hello my name is"], max_tokens=4)
     print(outputs[0][1])
+
+
+@pytest.mark.skipif(
+    not is_quant_method_supported("fp8"),
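The added decorator (cut off in the extracted hunk) gates the renamed test on FP8 hardware support. For reference, this is the usual shape of such a guard; the `is_quant_method_supported` body below is a hypothetical stand-in for vLLM's test utility, shown only to make the sketch self-contained:

import pytest
import torch

# Hypothetical stand-in for vLLM's is_quant_method_supported test helper.
def is_quant_method_supported(method: str) -> bool:
    if not torch.cuda.is_available():
        return False
    # FP8 kernels generally need Ada/Hopper-class GPUs (sm89 and up).
    return method != "fp8" or torch.cuda.get_device_capability() >= (8, 9)

@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
def test_online_quantization_smoke() -> None:
    # Placeholder body; the real test drives vllm_runner with an
    # unquantized checkpoint and quantization="fp8".
    assert is_quant_method_supported("fp8")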


@@ -230,9 +230,14 @@ class Fp8Config(QuantizationConfig):
                 fused_mapping=self.packed_modules_mapping,
             ):
                 return UnquantizedLinearMethod()
-            quant_method = Fp8LinearMethod(self)
-            quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
-            return quant_method
+            if not self.is_checkpoint_fp8_serialized:
+                online_method = Fp8OnlineLinearMethod(self)
+                online_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
+                return online_method
+            else:
+                offline_method = Fp8LinearMethod(self)
+                offline_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
+                return offline_method
         elif isinstance(layer, FusedMoE):
             if is_layer_skipped(
                 prefix=prefix,
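This dispatch is the heart of the change: `Fp8Config.get_quant_method` now branches purely on `is_checkpoint_fp8_serialized`, handing non-FP8 checkpoints to the new online class. A minimal self-contained sketch of the shape of that dispatch, using stub classes rather than vLLM's real ones:

# Stubs standing in for the real vLLM classes; only the flag drives the choice.
class Fp8LinearMethod:  # offline: checkpoint already stores FP8 weights
    def __init__(self, config):
        self.config = config

class Fp8OnlineLinearMethod(Fp8LinearMethod):  # online: quantize at load time
    pass

class Fp8Config:
    def __init__(self, is_checkpoint_fp8_serialized: bool):
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

    def get_linear_method(self) -> Fp8LinearMethod:
        if not self.is_checkpoint_fp8_serialized:
            return Fp8OnlineLinearMethod(self)
        return Fp8LinearMethod(self)

assert isinstance(Fp8Config(False).get_linear_method(), Fp8OnlineLinearMethod)
offline = Fp8Config(True).get_linear_method()
assert isinstance(offline, Fp8LinearMethod)
assert not isinstance(offline, Fp8OnlineLinearMethod)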
@@ -295,13 +300,8 @@ class Fp8LinearMethod(LinearMethodBase):
     Supports loading FP8 checkpoints with static weight scale and
     dynamic/static activation scale.
-    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
-    activation scaling. The weight scaling factor will be initialized after
-    the model weights are loaded.
 
     Limitations:
-    1. Only support per-tensor quantization due to torch._scaled_mm support.
-    2. Only support float8_e4m3fn data type due to the limitation of
+    1. Only support float8_e4m3fn data type due to the limitation of
        torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
 
     Args:
@@ -388,89 +388,43 @@ class Fp8LinearMethod(LinearMethodBase):
             self.weight_block_size,
         )
 
         # WEIGHT
-        if self.quant_config.is_checkpoint_fp8_serialized:
-            weight = create_fp8_weight_parameter(
-                output_size_per_partition, input_size_per_partition, weight_loader
-            )
-        else:
-
-            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
-                # track how many elements we have updated
-                if not hasattr(layer, "_loaded_numel"):
-                    layer._loaded_numel = 0
-
-                # load the current weight chunk
-                copy_numel_counter = CopyNumelCounter()
-                with copy_numel_counter:
-                    res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
-                layer._loaded_numel += copy_numel_counter.copied_numel
-
-                # if we have loaded all of the elements, call
-                # process_weights_after_loading
-                target_loaded_numel = layer.weight.numel()
-                if layer._loaded_numel == target_loaded_numel:
-                    self.process_weights_after_loading(layer)
-
-                    # Delete the bookkeeping
-                    del layer._loaded_numel
-
-                    # Prevent the usual `process_weights_after_loading` call from doing
-                    # anything
-                    layer._already_called_process_weights_after_loading = True
-
-                return res
-
-            # For non-serialized checkpoints, use original dtype
-            weight = ModelWeightParameter(
-                data=torch.empty(
-                    output_size_per_partition,
-                    input_size_per_partition,
-                    dtype=params_dtype,
-                ),
-                input_dim=1,
-                output_dim=0,
-                weight_loader=patched_weight_loader,
-            )
+        weight = create_fp8_weight_parameter(
+            output_size_per_partition, input_size_per_partition, weight_loader
+        )
         layer.register_parameter("weight", weight)
 
-        # If checkpoint is serialized fp8, load them.
-        # Otherwise, wait until process_weights_after_loading.
-        if self.quant_config.is_checkpoint_fp8_serialized:
-            # WEIGHT SCALE
-            if not self.block_quant:
-                scale = create_fp8_scale_parameter(
-                    PerTensorScaleParameter,
-                    output_partition_sizes,
-                    input_size_per_partition,
-                    None,
-                    weight_loader,
-                )
-                set_weight_attrs(scale, {"scale_type": "weight_scale"})
-                layer.register_parameter("weight_scale", scale)
-            else:
-                assert not self.act_q_static
-                assert self.weight_block_size is not None
-                scale = create_fp8_scale_parameter(
-                    BlockQuantScaleParameter,
-                    output_partition_sizes,
-                    input_size_per_partition,
-                    self.weight_block_size,
-                    weight_loader,
-                )
-                set_weight_attrs(scale, {"scale_type": "weight_scale"})
-                # The weight_scale_inv name is intentional for deepseekv3
-                layer.register_parameter("weight_scale_inv", scale)
+        # WEIGHT SCALE
+        if not self.block_quant:
+            scale = create_fp8_scale_parameter(
+                PerTensorScaleParameter,
+                output_partition_sizes,
+                input_size_per_partition,
+                None,
+                weight_loader,
+            )
+            set_weight_attrs(scale, {"scale_type": "weight_scale"})
+            layer.register_parameter("weight_scale", scale)
+        else:
+            assert not self.act_q_static
+            assert self.weight_block_size is not None
+            scale = create_fp8_scale_parameter(
+                BlockQuantScaleParameter,
+                output_partition_sizes,
+                input_size_per_partition,
+                self.weight_block_size,
+                weight_loader,
+            )
+            set_weight_attrs(scale, {"scale_type": "weight_scale"})
+            # The weight_scale_inv name is intentional for deepseekv3
+            layer.register_parameter("weight_scale_inv", scale)
 
-            # INPUT ACTIVATION SCALE
-            if self.act_q_static:
-                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
-                set_weight_attrs(scale, {"scale_type": "input_scale"})
-                layer.register_parameter("input_scale", scale)
+        # INPUT ACTIVATION SCALE
+        if self.act_q_static:
+            scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
+            set_weight_attrs(scale, {"scale_type": "input_scale"})
+            layer.register_parameter("input_scale", scale)
 
     def process_weights_after_loading(self, layer: Module) -> None:
         if getattr(layer, "_already_called_process_weights_after_loading", False):
             return
         size_k_first = True
         input_scale = None
         # TODO(rob): refactor block quant into separate class.
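The two scale layouts created above differ only in shape: the per-tensor strategy keeps one scale per logical output shard of a fused module, while the block strategy keeps a 2-D grid of scales tiled by `weight_block_size`. A rough illustration of the resulting shapes, with hypothetical helper names rather than vLLM's actual parameter classes:

# Illustrative shape math only; names are not vLLM's helpers.
def per_tensor_scale_shape(output_partition_sizes: list[int]) -> tuple[int, ...]:
    # One scale per logical shard (e.g., q/k/v of a fused qkv_proj).
    return (len(output_partition_sizes),)

def block_scale_shape(
    output_size: int, input_size: int, block: tuple[int, int]
) -> tuple[int, int]:
    block_n, block_k = block
    return (
        (output_size + block_n - 1) // block_n,  # ceil-div over output dim
        (input_size + block_k - 1) // block_k,   # ceil-div over input dim
    )

print(per_tensor_scale_shape([4096, 4096, 1024]))  # (3,): one scale per shard
print(block_scale_shape(4096, 14336, (128, 128)))  # (32, 112)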
@@ -488,31 +442,24 @@ class Fp8LinearMethod(LinearMethodBase):
-        # If checkpoint not serialized fp8, quantize the weights.
         else:
-            if not self.quant_config.is_checkpoint_fp8_serialized:
-                qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
-                weight = qweight.t()
-
-            # If checkpoint is fp8 per-tensor, handle that there are N scales for N
-            # shards in a fused module
-            else:
-                weight = layer.weight
-                weight_scale = layer.weight_scale
+            weight = layer.weight
+            weight_scale = layer.weight_scale
 
-                # If using w8a8, torch._scaled_mm needs per tensor, so
-                # requantize the logical shards as a single weight.
-                if not self.use_marlin:
-                    weight, weight_scale, input_scale = (
-                        process_fp8_weight_tensor_strategy(
-                            weight,
-                            weight_scale,
-                            layer.logical_widths,
-                            getattr(layer, "input_scale", None),
-                        )
-                    )
-                    if self.act_q_static:
-                        assert input_scale is not None
-                        input_scale = input_scale.max()
-                    weight = weight.t()
+            # If using w8a8, torch._scaled_mm needs per tensor, so
+            # requantize the logical shards as a single weight.
+            if not self.use_marlin:
+                weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
+                    weight,
+                    weight_scale,
+                    layer.logical_widths,
+                    getattr(layer, "input_scale", None),
+                )
+                if self.act_q_static:
+                    assert input_scale is not None
+                    input_scale = input_scale.max()
+                weight = weight.t()
 
         # Update layer with new values.
         replace_parameter(layer, "weight", weight.data)
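The reason for `process_fp8_weight_tensor_strategy` is that a fused module is loaded as N shards, each with its own FP8 scale, while `torch._scaled_mm` wants a single per-tensor scale. A hedged sketch of the underlying idea (dequantize each shard with its own scale, requantize everything with the max scale); this is illustrative, not the actual implementation:

import torch

def requantize_per_tensor(
    weight: torch.Tensor,        # fp8 weight, shards stacked along dim 0
    shard_scales: torch.Tensor,  # one float32 scale per shard
    logical_widths: list[int],   # rows belonging to each shard
) -> tuple[torch.Tensor, torch.Tensor]:
    max_scale = shard_scales.max()
    out = torch.empty_like(weight)
    start = 0
    for scale, width in zip(shard_scales, logical_widths):
        end = start + width
        # Dequantize with the shard's scale, requantize with the shared one.
        deq = weight[start:end].to(torch.float32) * scale
        out[start:end] = (deq / max_scale).clamp(-448.0, 448.0).to(weight.dtype)
        start = end
    return out, max_scale

Since every shard scale is at most `max_scale`, the requantized values shrink (never overflow); the clamp to the float8_e4m3fn range is just a safety net.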
@@ -607,6 +554,89 @@ class Fp8LinearMethod(LinearMethodBase):
         return self.fp8_linear.apply_weights(layer, x, bias)
 
 
+class Fp8OnlineLinearMethod(Fp8LinearMethod):
+    """Online version of Fp8LinearMethod, which loads the fp16/bf16
+    checkpoint and quantizes the weights during loading."""
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: list[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        layer.orig_dtype = params_dtype
+        layer.weight_block_size = None
+
+        # WEIGHT
+        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
+            # track how many elements we have updated
+            if not hasattr(layer, "_loaded_numel"):
+                layer._loaded_numel = 0
+
+            # load the current weight chunk
+            copy_numel_counter = CopyNumelCounter()
+            with copy_numel_counter:
+                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
+            layer._loaded_numel += copy_numel_counter.copied_numel
+
+            # if we have loaded all of the elements, call
+            # process_weights_after_loading
+            target_loaded_numel = layer.weight.numel()
+            if layer._loaded_numel == target_loaded_numel:
+                self.process_weights_after_loading(layer)
+
+                # Delete the bookkeeping
+                del layer._loaded_numel
+
+                # Prevent the usual `process_weights_after_loading` call from
+                # doing anything
+                layer._already_called_process_weights_after_loading = True
+
+            return res
+
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition,
+                dtype=params_dtype,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=patched_weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        if getattr(layer, "_already_called_process_weights_after_loading", False):
+            return
+
+        # TODO(future): support block_quant in online quant path
+        assert not self.block_quant
+
+        layer.input_scale = None
+        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
+        weight = qweight.t()
+
+        # Update layer with new values.
+        replace_parameter(layer, "weight", weight.data)
+        replace_parameter(layer, "weight_scale", weight_scale.data)
+
+        if self.use_marlin:
+            size_k_first = True
+            prepare_fp8_layer_for_marlin(
+                layer, size_k_first, input_dtype=self.marlin_input_dtype
+            )
+            # Activations not quantized for marlin.
+
+
 class Fp8MoEMethod(FusedMoEMethodBase):
     """MoE method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
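The eager-quantization hook in `patched_weight_loader` hinges on `CopyNumelCounter`: it counts how many elements the wrapped loader actually copied, so the layer knows when the last shard of a fused weight has landed and can quantize immediately (freeing the original-dtype weight early instead of waiting for the global `process_weights_after_loading` pass). A hedged sketch of what such a counter can look like, assuming a `TorchFunctionMode`-based implementation; vLLM's real helper may differ:

import torch
from torch.overrides import TorchFunctionMode

class CopyNumelCounter(TorchFunctionMode):
    """Counts elements flowing through Tensor.copy_ while active."""

    def __init__(self):
        super().__init__()
        self.copied_numel = 0

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.Tensor.copy_:
            # args[0] is the destination tensor (here, a shard view).
            self.copied_numel += args[0].numel()
        return func(*args, **kwargs)

param = torch.empty(4, 8)
counter = CopyNumelCounter()
with counter:
    param[:2].copy_(torch.ones(2, 8))  # load the first "shard"
    param[2:].copy_(torch.ones(2, 8))  # load the second "shard"
assert counter.copied_numel == param.numel()  # 32: all elements loaded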
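End to end, the online path in `Fp8OnlineLinearMethod.process_weights_after_loading` amounts to a per-tensor FP8 quantization at load time followed by an FP8 GEMM at runtime. A plain-PyTorch sketch of that picture, with a hand-rolled stand-in for vLLM's `ops.scaled_fp8_quant` kernel (assumes PyTorch 2.3+ and an FP8-capable GPU, sm89 or newer):

import torch

def scaled_fp8_quant(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Per-tensor quantization stand-in, not vLLM's fused kernel.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = w.abs().max().float() / finfo.max
    qw = (w.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return qw, scale

weight = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
qweight, w_scale = scaled_fp8_quant(weight)  # done once, at load time

x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
qx, x_scale = scaled_fp8_quant(x)            # dynamic activation quant
out = torch._scaled_mm(
    qx, qweight.t(),                         # _scaled_mm wants column-major B,
    scale_a=x_scale, scale_b=w_scale,        # hence the .t() on the weight
    out_dtype=torch.bfloat16,
)
print(out.shape)  # torch.Size([16, 4096])

The `.t()` on the quantized weight mirrors the `weight = qweight.t()` in the diff: `torch._scaled_mm` requires its second operand in column-major layout, which the transpose view provides for free.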