From 5869f69c5f543c4b3cbd460a30c10eb97e5b3ca0 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 31 Mar 2026 10:56:43 -0400 Subject: [PATCH] [Online Quant] [QeRL] Minor code cleanup (#38574) Signed-off-by: Kyle Sayers --- .../layers/quantization/mxfp8.py | 65 +------------------ .../model_loader/reload/layerwise.py | 2 +- 2 files changed, 2 insertions(+), 65 deletions(-) diff --git a/vllm/model_executor/layers/quantization/mxfp8.py b/vllm/model_executor/layers/quantization/mxfp8.py index bd29f272b..07c519b2a 100644 --- a/vllm/model_executor/layers/quantization/mxfp8.py +++ b/vllm/model_executor/layers/quantization/mxfp8.py @@ -31,7 +31,6 @@ from vllm.model_executor.layers.quantization.fp8 import ( Fp8KVCacheMethod, Fp8OnlineLinearMethod, Fp8OnlineMoEMethod, - _copy_missing_attrs, ) from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( MXFP8_BLOCK_SIZE, @@ -43,11 +42,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped, ) -from vllm.model_executor.model_loader.weight_utils import ( - initialize_single_dummy_weight, -) -from vllm.model_executor.parameter import ModelWeightParameter -from vllm.model_executor.utils import replace_parameter, set_weight_attrs +from vllm.model_executor.utils import replace_parameter from vllm.platforms import current_platform logger = init_logger(__name__) @@ -183,17 +178,6 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod): if getattr(layer, "_already_called_process_weights_after_loading", False): return - if layer.weight.device == torch.device("meta"): - weight = ModelWeightParameter( - data=torch.empty_like(layer.weight, device=layer._load_device), - input_dim=1, - output_dim=0, - weight_loader=layer.weight.weight_loader, - ) - _copy_missing_attrs(layer.weight, weight) - layer.register_parameter("weight", weight) - initialize_single_dummy_weight(layer.weight) - weight_fp8, weight_scale = mxfp8_e4m3_quantize(layer.weight.contiguous()) if self.mxfp8_linear.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS: @@ -265,28 +249,6 @@ class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod): **extra_weight_attrs, ) - w13_weight_scale = torch.nn.Parameter( - torch.zeros( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // MXFP8_BLOCK_SIZE, - dtype=torch.uint8, - ), - requires_grad=False, - ) - w2_weight_scale = torch.nn.Parameter( - torch.zeros( - num_experts, - hidden_size, - intermediate_size_per_partition // MXFP8_BLOCK_SIZE, - dtype=torch.uint8, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) layer.weight_block_size = [1, MXFP8_BLOCK_SIZE] def _quantize_mxfp8_moe_weight( @@ -309,34 +271,9 @@ class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod): if getattr(layer, "_already_called_process_weights_after_loading", False): return - if layer.w13_weight.device == torch.device("meta"): - w13_weight = torch.nn.Parameter( - torch.empty_like(layer.w13_weight, device=layer._load_device), - requires_grad=False, - ) - set_weight_attrs( - w13_weight, {"weight_loader": layer.w13_weight.weight_loader} - ) - _copy_missing_attrs(layer.w13_weight, w13_weight) - layer.register_parameter("w13_weight", w13_weight) - initialize_single_dummy_weight(layer.w13_weight) - if layer.w2_weight.device == torch.device("meta"): - w2_weight = torch.nn.Parameter( - torch.empty_like(layer.w2_weight, device=layer._load_device), - requires_grad=False, - ) - set_weight_attrs( - w2_weight, {"weight_loader": layer.w2_weight.weight_loader} - ) - _copy_missing_attrs(layer.w2_weight, w2_weight) - layer.register_parameter("w2_weight", w2_weight) - initialize_single_dummy_weight(layer.w2_weight) - fp8_dtype = current_platform.fp8_dtype() w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype) w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype) - w13_scale = layer.w13_weight_scale - w2_scale = layer.w2_weight_scale layer.w13_input_scale = None layer.w2_input_scale = None diff --git a/vllm/model_executor/model_loader/reload/layerwise.py b/vllm/model_executor/model_loader/reload/layerwise.py index 3ad0cdb84..f7b1c13bb 100644 --- a/vllm/model_executor/model_loader/reload/layerwise.py +++ b/vllm/model_executor/model_loader/reload/layerwise.py @@ -230,7 +230,7 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon continue # reloading: place kernel tensors back as a fallback - else: + elif info.load_numel_total > 0: # type: ignore[operator] logger.warning("%s: Failed to load weights", layer.__class__.__name__) _place_kernel_tensors(layer, info)