diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py
index e8abe0d41..9226f2a55 100644
--- a/tests/quantization/test_fp8.py
+++ b/tests/quantization/test_fp8.py
@@ -133,7 +133,7 @@ def test_kv_cache_model_load_and_run(
 @pytest.mark.parametrize(
     "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
 )
-def test_load_fp16_model(
+def test_online_quantization(
     vllm_runner,
     kv_cache_dtype: str,
     force_marlin: bool,
@@ -191,6 +191,9 @@ def test_load_fp16_model(
 
     llm.apply_model(check_model)
 
+    outputs = llm.generate_greedy(["Hello my name is"], max_tokens=4)
+    print(outputs[0][1])
+
 
 @pytest.mark.skipif(
     not is_quant_method_supported("fp8"),
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 6636c05cd..5adcd09b0 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -230,9 +230,14 @@ class Fp8Config(QuantizationConfig):
                 fused_mapping=self.packed_modules_mapping,
             ):
                 return UnquantizedLinearMethod()
-            quant_method = Fp8LinearMethod(self)
-            quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
-            return quant_method
+            if not self.is_checkpoint_fp8_serialized:
+                online_method = Fp8OnlineLinearMethod(self)
+                online_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
+                return online_method
+            else:
+                offline_method = Fp8LinearMethod(self)
+                offline_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
+                return offline_method
         elif isinstance(layer, FusedMoE):
             if is_layer_skipped(
                 prefix=prefix,
@@ -295,13 +300,8 @@ class Fp8LinearMethod(LinearMethodBase):
     Supports loading FP8 checkpoints with static weight scale and
     dynamic/static activation scale.
 
-    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
-    activation scaling. The weight scaling factor will be initialized after
-    the model weights are loaded.
-
     Limitations:
-    1. Only support per-tensor quantization due to torch._scaled_mm support.
-    2. Only support float8_e4m3fn data type due to the limitation of
+    1. Only support float8_e4m3fn data type due to the limitation of
        torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
 
     Args:
@@ -388,89 +388,43 @@ class Fp8LinearMethod(LinearMethodBase):
                 self.weight_block_size,
             )
 
-        # WEIGHT
-        if self.quant_config.is_checkpoint_fp8_serialized:
-            weight = create_fp8_weight_parameter(
-                output_size_per_partition, input_size_per_partition, weight_loader
-            )
-        else:
-
-            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
-                # track how many elements we have updated
-                if not hasattr(layer, "_loaded_numel"):
-                    layer._loaded_numel = 0
-
-                # load the current weight chunk
-                copy_numel_counter = CopyNumelCounter()
-                with copy_numel_counter:
-                    res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
-                layer._loaded_numel += copy_numel_counter.copied_numel
-
-                # if we have loaded all of the elements, call
-                # process_weights_after_loading
-                target_loaded_numel = layer.weight.numel()
-                if layer._loaded_numel == target_loaded_numel:
-                    self.process_weights_after_loading(layer)
-
-                    # Delete the bookkeeping
-                    del layer._loaded_numel
-                    # Prevent the usual `process_weights_after_loading` call from doing
-                    # anything
-                    layer._already_called_process_weights_after_loading = True
-
-                return res
-
-            # For non-serialized checkpoints, use original dtype
-            weight = ModelWeightParameter(
-                data=torch.empty(
-                    output_size_per_partition,
-                    input_size_per_partition,
-                    dtype=params_dtype,
-                ),
-                input_dim=1,
-                output_dim=0,
-                weight_loader=patched_weight_loader,
-            )
+        weight = create_fp8_weight_parameter(
+            output_size_per_partition, input_size_per_partition, weight_loader
+        )
         layer.register_parameter("weight", weight)
 
-        # If checkpoint is serialized fp8, load them.
-        # Otherwise, wait until process_weights_after_loading.
-        if self.quant_config.is_checkpoint_fp8_serialized:
-            # WEIGHT SCALE
-            if not self.block_quant:
-                scale = create_fp8_scale_parameter(
-                    PerTensorScaleParameter,
-                    output_partition_sizes,
-                    input_size_per_partition,
-                    None,
-                    weight_loader,
-                )
-                set_weight_attrs(scale, {"scale_type": "weight_scale"})
-                layer.register_parameter("weight_scale", scale)
-            else:
-                assert not self.act_q_static
-                assert self.weight_block_size is not None
-                scale = create_fp8_scale_parameter(
-                    BlockQuantScaleParameter,
-                    output_partition_sizes,
-                    input_size_per_partition,
-                    self.weight_block_size,
-                    weight_loader,
-                )
-                set_weight_attrs(scale, {"scale_type": "weight_scale"})
-                # The weight_scale_inv name is intentional for deepseekv3
-                layer.register_parameter("weight_scale_inv", scale)
+        # WEIGHT SCALE
+        if not self.block_quant:
+            scale = create_fp8_scale_parameter(
+                PerTensorScaleParameter,
+                output_partition_sizes,
+                input_size_per_partition,
+                None,
+                weight_loader,
+            )
+            set_weight_attrs(scale, {"scale_type": "weight_scale"})
+            layer.register_parameter("weight_scale", scale)
+        else:
+            assert not self.act_q_static
+            assert self.weight_block_size is not None
+            scale = create_fp8_scale_parameter(
+                BlockQuantScaleParameter,
+                output_partition_sizes,
+                input_size_per_partition,
+                self.weight_block_size,
+                weight_loader,
+            )
+            set_weight_attrs(scale, {"scale_type": "weight_scale"})
+            # The weight_scale_inv name is intentional for deepseekv3
+            layer.register_parameter("weight_scale_inv", scale)
 
-            # INPUT ACTIVATION SCALE
-            if self.act_q_static:
-                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
-                set_weight_attrs(scale, {"scale_type": "input_scale"})
-                layer.register_parameter("input_scale", scale)
+        # INPUT ACTIVATION SCALE
+        if self.act_q_static:
+            scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
+            set_weight_attrs(scale, {"scale_type": "input_scale"})
+            layer.register_parameter("input_scale", scale)
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        if getattr(layer, "_already_called_process_weights_after_loading", False):
-            return
-
         size_k_first = True
         input_scale = None
         # TODO(rob): refactor block quant into separate class.
@@ -488,31 +442,24 @@ class Fp8LinearMethod(LinearMethodBase):
 
         # If checkpoint not serialized fp8, quantize the weights.
         else:
-            if not self.quant_config.is_checkpoint_fp8_serialized:
-                qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
-                weight = qweight.t()
-
             # If checkpoint is fp8 per-tensor, handle that there are N scales for N
             # shards in a fused module
-            else:
-                weight = layer.weight
-                weight_scale = layer.weight_scale
+            weight = layer.weight
+            weight_scale = layer.weight_scale
 
-            # If using w8a8, torch._scaled_mm needs per tensor, so
-            # requantize the logical shards as a single weight.
+            if not self.use_marlin:
+                weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
+                    weight,
+                    weight_scale,
+                    layer.logical_widths,
+                    getattr(layer, "input_scale", None),
+                )
+                if self.act_q_static:
+                    assert input_scale is not None
+                    input_scale = input_scale.max()
+                weight = weight.t()
 
         # Update layer with new values.
         replace_parameter(layer, "weight", weight.data)
@@ -607,6 +554,89 @@ class Fp8LinearMethod(LinearMethodBase):
         return self.fp8_linear.apply_weights(layer, x, bias)
 
 
+class Fp8OnlineLinearMethod(Fp8LinearMethod):
+    """Online version of Fp8LinearMethod that loads the fp16/bf16 checkpoint
+    and quantizes the weights during loading."""
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: list[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        layer.orig_dtype = params_dtype
+        layer.weight_block_size = None
+
+        # WEIGHT
+        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
+            # track how many elements we have updated
+            if not hasattr(layer, "_loaded_numel"):
+                layer._loaded_numel = 0
+
+            # load the current weight chunk
+            copy_numel_counter = CopyNumelCounter()
+            with copy_numel_counter:
+                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
+            layer._loaded_numel += copy_numel_counter.copied_numel
+
+            # if we have loaded all of the elements, call
+            # process_weights_after_loading
+            target_loaded_numel = layer.weight.numel()
+            if layer._loaded_numel == target_loaded_numel:
+                self.process_weights_after_loading(layer)
+
+                # Delete the bookkeeping
+                del layer._loaded_numel
+                # Prevent the usual `process_weights_after_loading` call from doing
+                # anything
+                layer._already_called_process_weights_after_loading = True
+
+            return res
+
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition,
+                dtype=params_dtype,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=patched_weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        if getattr(layer, "_already_called_process_weights_after_loading", False):
+            return
+
+        # TODO(future): support block_quant in online quant path
+        assert not self.block_quant
+
+        layer.input_scale = None
+        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
+        weight = qweight.t()
+
+        # Update layer with new values.
+        replace_parameter(layer, "weight", weight.data)
+        replace_parameter(layer, "weight_scale", weight_scale.data)
+
+        if self.use_marlin:
+            size_k_first = True
+            prepare_fp8_layer_for_marlin(
+                layer, size_k_first, input_dtype=self.marlin_input_dtype
+            )
+            # Activations not quantized for marlin.
+
+
 class Fp8MoEMethod(FusedMoEMethodBase):
     """MoE method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
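
Usage note (not part of the patch): with this change, requesting fp8 quantization for an fp16/bf16 checkpoint routes linear layers through `Fp8OnlineLinearMethod`, which quantizes weights as they finish loading, while fp8-serialized checkpoints keep using `Fp8LinearMethod`. A minimal sketch with the public vLLM API, mirroring what `test_online_quantization` exercises; the model name is only illustrative:

```python
from vllm import LLM, SamplingParams

# A bf16/fp16 checkpoint plus quantization="fp8" takes the online path,
# since the checkpoint is not fp8-serialized.
llm = LLM(model="facebook/opt-125m", quantization="fp8")
params = SamplingParams(temperature=0.0, max_tokens=4)
outputs = llm.generate(["Hello my name is"], params)
print(outputs[0].outputs[0].text)
```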
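The core trick in `patched_weight_loader` is bookkeeping: count the elements written by each shard copy and quantize only once the whole tensor has been filled, so a fused layer is quantized exactly once, after its last shard arrives. Below is a standalone sketch of that pattern under assumed per-tensor scaling; `quantize_fp8` here is a hypothetical stand-in, not vLLM's `ops.scaled_fp8_quant`, and it requires a PyTorch build with `float8_e4m3fn`:

```python
import torch

def quantize_fp8(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Per-tensor dynamic scale: map the max-abs value onto the e4m3 range.
    scale = weight.abs().amax().float() / torch.finfo(torch.float8_e4m3fn).max
    qweight = (weight.float() / scale).to(torch.float8_e4m3fn)
    return qweight, scale

# A fused layer's weight is filled shard by shard; quantize only when complete.
weight = torch.empty(8, 16, dtype=torch.bfloat16)
loaded_numel = 0
for row_start in (0, 4):  # two output shards of a fused projection
    shard = torch.randn(4, 16, dtype=torch.bfloat16)
    weight[row_start:row_start + 4].copy_(shard)
    loaded_numel += shard.numel()
    if loaded_numel == weight.numel():  # all shards loaded -> quantize once
        qweight, scale = quantize_fp8(weight)
        print(qweight.dtype, qweight.shape, scale.item())
```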