fp8 online quant: split out Fp8OnlineLinearMethod (#32189)
This commit is contained in:
committed by
GitHub
parent
22375f8d13
commit
d2389c1262
@@ -133,7 +133,7 @@ def test_kv_cache_model_load_and_run(
|
||||
@pytest.mark.parametrize(
|
||||
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
|
||||
)
|
||||
def test_load_fp16_model(
|
||||
def test_online_quantization(
|
||||
vllm_runner,
|
||||
kv_cache_dtype: str,
|
||||
force_marlin: bool,
|
||||
@@ -191,6 +191,9 @@ def test_load_fp16_model(
|
||||
|
||||
llm.apply_model(check_model)
|
||||
|
||||
outputs = llm.generate_greedy(["Hello my name is"], max_tokens=4)
|
||||
print(outputs[0][1])
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("fp8"),
|
||||
|
||||
@@ -230,9 +230,14 @@ class Fp8Config(QuantizationConfig):
|
||||
fused_mapping=self.packed_modules_mapping,
|
||||
):
|
||||
return UnquantizedLinearMethod()
|
||||
quant_method = Fp8LinearMethod(self)
|
||||
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
|
||||
return quant_method
|
||||
if not self.is_checkpoint_fp8_serialized:
|
||||
online_method = Fp8OnlineLinearMethod(self)
|
||||
online_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
|
||||
return online_method
|
||||
else:
|
||||
offline_method = Fp8LinearMethod(self)
|
||||
offline_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
|
||||
return offline_method
|
||||
elif isinstance(layer, FusedMoE):
|
||||
if is_layer_skipped(
|
||||
prefix=prefix,
|
||||
@@ -295,13 +300,8 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
Supports loading FP8 checkpoints with static weight scale and
|
||||
dynamic/static activation scale.
|
||||
|
||||
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
|
||||
activation scaling. The weight scaling factor will be initialized after
|
||||
the model weights are loaded.
|
||||
|
||||
Limitations:
|
||||
1. Only support per-tensor quantization due to torch._scaled_mm support.
|
||||
2. Only support float8_e4m3fn data type due to the limitation of
|
||||
1. Only support float8_e4m3fn data type due to the limitation of
|
||||
torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
|
||||
|
||||
Args:
|
||||
@@ -388,89 +388,43 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
self.weight_block_size,
|
||||
)
|
||||
|
||||
# WEIGHT
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
weight = create_fp8_weight_parameter(
|
||||
output_size_per_partition, input_size_per_partition, weight_loader
|
||||
)
|
||||
else:
|
||||
|
||||
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
|
||||
# track how many elements we have updated
|
||||
if not hasattr(layer, "_loaded_numel"):
|
||||
layer._loaded_numel = 0
|
||||
|
||||
# load the current weight chunk
|
||||
copy_numel_counter = CopyNumelCounter()
|
||||
with copy_numel_counter:
|
||||
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
|
||||
layer._loaded_numel += copy_numel_counter.copied_numel
|
||||
|
||||
# if we have loaded all of the elements, call
|
||||
# process_weights_after_loading
|
||||
target_loaded_numel = layer.weight.numel()
|
||||
if layer._loaded_numel == target_loaded_numel:
|
||||
self.process_weights_after_loading(layer)
|
||||
|
||||
# Delete the bookkeeping
|
||||
del layer._loaded_numel
|
||||
# Prevent the usual `process_weights_after_loading` call from doing
|
||||
# anything
|
||||
layer._already_called_process_weights_after_loading = True
|
||||
|
||||
return res
|
||||
|
||||
# For non-serialized checkpoints, use original dtype
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=patched_weight_loader,
|
||||
)
|
||||
weight = create_fp8_weight_parameter(
|
||||
output_size_per_partition, input_size_per_partition, weight_loader
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# If checkpoint is serialized fp8, load them.
|
||||
# Otherwise, wait until process_weights_after_loading.
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
# WEIGHT SCALE
|
||||
if not self.block_quant:
|
||||
scale = create_fp8_scale_parameter(
|
||||
PerTensorScaleParameter,
|
||||
output_partition_sizes,
|
||||
input_size_per_partition,
|
||||
None,
|
||||
weight_loader,
|
||||
)
|
||||
set_weight_attrs(scale, {"scale_type": "weight_scale"})
|
||||
layer.register_parameter("weight_scale", scale)
|
||||
else:
|
||||
assert not self.act_q_static
|
||||
assert self.weight_block_size is not None
|
||||
scale = create_fp8_scale_parameter(
|
||||
BlockQuantScaleParameter,
|
||||
output_partition_sizes,
|
||||
input_size_per_partition,
|
||||
self.weight_block_size,
|
||||
weight_loader,
|
||||
)
|
||||
set_weight_attrs(scale, {"scale_type": "weight_scale"})
|
||||
# The weight_scale_inv name is intentional for deepseekv3
|
||||
layer.register_parameter("weight_scale_inv", scale)
|
||||
# WEIGHT SCALE
|
||||
if not self.block_quant:
|
||||
scale = create_fp8_scale_parameter(
|
||||
PerTensorScaleParameter,
|
||||
output_partition_sizes,
|
||||
input_size_per_partition,
|
||||
None,
|
||||
weight_loader,
|
||||
)
|
||||
set_weight_attrs(scale, {"scale_type": "weight_scale"})
|
||||
layer.register_parameter("weight_scale", scale)
|
||||
else:
|
||||
assert not self.act_q_static
|
||||
assert self.weight_block_size is not None
|
||||
scale = create_fp8_scale_parameter(
|
||||
BlockQuantScaleParameter,
|
||||
output_partition_sizes,
|
||||
input_size_per_partition,
|
||||
self.weight_block_size,
|
||||
weight_loader,
|
||||
)
|
||||
set_weight_attrs(scale, {"scale_type": "weight_scale"})
|
||||
# The weight_scale_inv name is intentional for deepseekv3
|
||||
layer.register_parameter("weight_scale_inv", scale)
|
||||
|
||||
# INPUT ACTIVATION SCALE
|
||||
if self.act_q_static:
|
||||
scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
|
||||
set_weight_attrs(scale, {"scale_type": "input_scale"})
|
||||
layer.register_parameter("input_scale", scale)
|
||||
# INPUT ACTIVATION SCALE
|
||||
if self.act_q_static:
|
||||
scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
|
||||
set_weight_attrs(scale, {"scale_type": "input_scale"})
|
||||
layer.register_parameter("input_scale", scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||
return
|
||||
|
||||
size_k_first = True
|
||||
input_scale = None
|
||||
# TODO(rob): refactor block quant into separate class.
|
||||
@@ -488,31 +442,24 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
|
||||
# If checkpoint not serialized fp8, quantize the weights.
|
||||
else:
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
|
||||
weight = qweight.t()
|
||||
|
||||
# If checkpoint is fp8 per-tensor, handle that there are N scales for N
|
||||
# shards in a fused module
|
||||
else:
|
||||
weight = layer.weight
|
||||
weight_scale = layer.weight_scale
|
||||
weight = layer.weight
|
||||
weight_scale = layer.weight_scale
|
||||
|
||||
# If using w8a8, torch._scaled_mm needs per tensor, so
|
||||
# requantize the logical shards as a single weight.
|
||||
if not self.use_marlin:
|
||||
weight, weight_scale, input_scale = (
|
||||
process_fp8_weight_tensor_strategy(
|
||||
weight,
|
||||
weight_scale,
|
||||
layer.logical_widths,
|
||||
getattr(layer, "input_scale", None),
|
||||
)
|
||||
)
|
||||
if self.act_q_static:
|
||||
assert input_scale is not None
|
||||
input_scale = input_scale.max()
|
||||
weight = weight.t()
|
||||
# If using w8a8, torch._scaled_mm needs per tensor, so
|
||||
# requantize the logical shards as a single weight.
|
||||
if not self.use_marlin:
|
||||
weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
|
||||
weight,
|
||||
weight_scale,
|
||||
layer.logical_widths,
|
||||
getattr(layer, "input_scale", None),
|
||||
)
|
||||
if self.act_q_static:
|
||||
assert input_scale is not None
|
||||
input_scale = input_scale.max()
|
||||
weight = weight.t()
|
||||
|
||||
# Update layer with new values.
|
||||
replace_parameter(layer, "weight", weight.data)
|
||||
@@ -607,6 +554,89 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
|
||||
|
||||
class Fp8OnlineLinearMethod(Fp8LinearMethod):
|
||||
"""Online version of Fp8LinearMethod, loads the fp16/bf16 checkpoint
|
||||
and quantized the weights during loading."""
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
|
||||
# WEIGHT
|
||||
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
|
||||
# track how many elements we have updated
|
||||
if not hasattr(layer, "_loaded_numel"):
|
||||
layer._loaded_numel = 0
|
||||
|
||||
# load the current weight chunk
|
||||
copy_numel_counter = CopyNumelCounter()
|
||||
with copy_numel_counter:
|
||||
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
|
||||
layer._loaded_numel += copy_numel_counter.copied_numel
|
||||
|
||||
# if we have loaded all of the elements, call
|
||||
# process_weights_after_loading
|
||||
target_loaded_numel = layer.weight.numel()
|
||||
if layer._loaded_numel == target_loaded_numel:
|
||||
self.process_weights_after_loading(layer)
|
||||
|
||||
# Delete the bookkeeping
|
||||
del layer._loaded_numel
|
||||
# Prevent the usual `process_weights_after_loading` call from doing
|
||||
# anything
|
||||
layer._already_called_process_weights_after_loading = True
|
||||
|
||||
return res
|
||||
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=patched_weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||
return
|
||||
|
||||
# TODO(future): support block_quant in online quant path
|
||||
assert not self.block_quant
|
||||
|
||||
layer.input_scale = None
|
||||
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
|
||||
weight = qweight.t()
|
||||
|
||||
# Update layer with new values.
|
||||
replace_parameter(layer, "weight", weight.data)
|
||||
replace_parameter(layer, "weight_scale", weight_scale.data)
|
||||
|
||||
if self.use_marlin:
|
||||
size_k_first = True
|
||||
prepare_fp8_layer_for_marlin(
|
||||
layer, size_k_first, input_dtype=self.marlin_input_dtype
|
||||
)
|
||||
# Activations not quantized for marlin.
|
||||
|
||||
|
||||
class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
"""MoE method for FP8.
|
||||
Supports loading FP8 checkpoints with static weight scale and
|
||||
|
||||
Reference in New Issue
Block a user