[Misc] Update fbgemmfp8 to use vLLMParameters (#7972)

Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
Dipika Sikka
2024-09-03 22:12:41 -04:00
committed by GitHub
parent 61f4a93d14
commit e16fa99a6a
3 changed files with 22 additions and 41 deletions

View File

@@ -15,8 +15,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, create_per_channel_scale_param)
from vllm.model_executor.utils import set_weight_attrs
apply_fp8_linear)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter)
from vllm.platforms import current_platform
logger = init_logger(__name__)
@@ -85,6 +86,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
weight_loader = extra_weight_attrs.get("weight_loader")
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes)
@@ -95,20 +97,21 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
layer.orig_dtype = params_dtype
# WEIGHT
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
requires_grad=False)
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
"input_dim": 1,
"output_dim": 0,
**extra_weight_attrs,
})
# WEIGHT SCALE
weight_scale = create_per_channel_scale_param(output_partition_sizes,
**extra_weight_attrs)
weight_scale = ChannelQuantScaleParameter(data=torch.empty(
(sum(output_partition_sizes), 1), dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE UPPER BOUND
@@ -118,6 +121,11 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
layer.input_scale_ub = input_scale_ub
def process_weights_after_loading(self, layer: Module) -> None:
# required by torch.compile
layer.weight_scale = Parameter(layer.weight_scale.data,
requires_grad=False)
layer.weight = Parameter(layer.weight.data, requires_grad=False)
weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False)