[Online Quant] [QeRL] Minor code cleanup (#38574)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
@@ -31,7 +31,6 @@ from vllm.model_executor.layers.quantization.fp8 import (
|
||||
Fp8KVCacheMethod,
|
||||
Fp8OnlineLinearMethod,
|
||||
Fp8OnlineMoEMethod,
|
||||
- _copy_missing_attrs,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
|
||||
MXFP8_BLOCK_SIZE,
|
||||
@@ -43,11 +42,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
initialize_single_dummy_weight,
|
||||
)
|
||||
from vllm.model_executor.parameter import ModelWeightParameter
|
||||
- from vllm.model_executor.utils import replace_parameter, set_weight_attrs
|
||||
+ from vllm.model_executor.utils import replace_parameter
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -183,17 +178,6 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
|
||||
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||
return
|
||||
|
||||
if layer.weight.device == torch.device("meta"):
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty_like(layer.weight, device=layer._load_device),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=layer.weight.weight_loader,
|
||||
)
|
||||
_copy_missing_attrs(layer.weight, weight)
|
||||
layer.register_parameter("weight", weight)
|
||||
initialize_single_dummy_weight(layer.weight)
|
||||
|
||||
weight_fp8, weight_scale = mxfp8_e4m3_quantize(layer.weight.contiguous())
|
||||
|
||||
if self.mxfp8_linear.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS:
|
||||
@@ -265,28 +249,6 @@ class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod):
|
||||
**extra_weight_attrs,
|
||||
)
|
||||
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // MXFP8_BLOCK_SIZE,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition // MXFP8_BLOCK_SIZE,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
layer.weight_block_size = [1, MXFP8_BLOCK_SIZE]
|
||||
|
||||
def _quantize_mxfp8_moe_weight(
|
||||
@@ -309,34 +271,9 @@ class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod):
|
||||
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||
return
|
||||
|
||||
if layer.w13_weight.device == torch.device("meta"):
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty_like(layer.w13_weight, device=layer._load_device),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
w13_weight, {"weight_loader": layer.w13_weight.weight_loader}
|
||||
)
|
||||
_copy_missing_attrs(layer.w13_weight, w13_weight)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
initialize_single_dummy_weight(layer.w13_weight)
|
||||
if layer.w2_weight.device == torch.device("meta"):
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty_like(layer.w2_weight, device=layer._load_device),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
w2_weight, {"weight_loader": layer.w2_weight.weight_loader}
|
||||
)
|
||||
_copy_missing_attrs(layer.w2_weight, w2_weight)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
initialize_single_dummy_weight(layer.w2_weight)
|
||||
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
|
||||
w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
|
||||
w13_scale = layer.w13_weight_scale
|
||||
w2_scale = layer.w2_weight_scale
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
|
||||
@@ -230,7 +230,7 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
|
||||
continue
|
||||
|
||||
# reloading: place kernel tensors back as a fallback
|
||||
- else:
|
||||
+ elif info.load_numel_total > 0: # type: ignore[operator]
|
||||
logger.warning("%s: Failed to load weights", layer.__class__.__name__)
|
||||
_place_kernel_tensors(layer, info)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user