[Misc][Refactor] Generalize linear_method to be quant_method (#4373)

Cody Yu
2024-04-26 13:41:14 -07:00
committed by GitHub
parent 603ad84815
commit a62aaf1df5
45 changed files with 759 additions and 713 deletions
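The diff below is one of the 45 touched files and shows the heart of the refactor: QuantizationConfig.get_linear_method(), which could only describe linear layers, becomes get_quant_method(layer), which sees the concrete layer and returns a generic QuantizeMethodBase, or None to leave that layer unquantized. A minimal sketch of the interface shape as this file uses it (the real definitions live in base_config.py and linear.py, which are not shown in this diff):

import torch
from abc import ABC, abstractmethod
from typing import Optional


class QuantizeMethodBase(ABC):
    """Sketch inferred from this diff, not the verbatim vLLM definition."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module, *args, **kwargs):
        """Register (quantized) weights and scales on the layer."""

    @abstractmethod
    def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
        """Run the layer's forward pass with those weights."""

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Optional hook; Fp8LinearMethod quantizes checkpoint weights here."""


class QuantizationConfig(ABC):
    """Sketch of the relevant part of the config base class."""

    @abstractmethod
    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
        """Return a method for this layer, or None to keep it unquantized."""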

vllm/model_executor/layers/quantization/fp8.py

@@ -1,16 +1,17 @@
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional
 
 import torch
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 
-from vllm.model_executor.layers.linear import (LinearMethodBase,
-                                               set_weight_attrs)
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.utils import set_weight_attrs
 
 
-class FP8Config(QuantizationConfig):
+class Fp8Config(QuantizationConfig):
     """Config class for FP8."""
 
     @classmethod
@@ -33,11 +34,14 @@ class FP8Config(QuantizationConfig):
         return []
 
     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "FP8Config":
+    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
         return cls()
 
-    def get_linear_method(self) -> "Fp8LinearMethod":
-        return Fp8LinearMethod(self)
+    def get_quant_method(
+            self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
+        if isinstance(layer, LinearBase):
+            return Fp8LinearMethod(self)
+        return None
 
     def get_scaled_act_names(self) -> List[str]:
         return []
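Call sites across the other touched files change mechanically to match: a layer asks the config for its method with get_quant_method(self) and invokes apply(...) (previously apply_weights(...)) in its forward pass. A self-contained toy of that dispatch, with made-up names (ToyConfig, ToyLinearMethod, ToyLinear) standing in for the vLLM classes:

from typing import Optional

import torch
from torch import nn


class ToyLinearMethod:
    """Stand-in for a QuantizeMethodBase such as Fp8LinearMethod."""

    def apply(self, layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
        return x @ layer.weight.t()  # vLLM would call a quantized kernel here


class ToyConfig:
    """Stand-in for a QuantizationConfig."""

    def get_quant_method(self, layer: nn.Module) -> Optional[ToyLinearMethod]:
        # Only linear-like layers get a method; anything else stays unquantized.
        return ToyLinearMethod() if isinstance(layer, ToyLinear) else None


class ToyLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, config: ToyConfig):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        # The layer hands itself to the config, which decides what to return.
        self.quant_method = config.get_quant_method(self)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.quant_method.apply(self, x)  # was apply_weights(...)


layer = ToyLinear(8, 4, ToyConfig())
print(layer(torch.randn(2, 8)).shape)  # torch.Size([2, 4])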
@@ -57,7 +61,7 @@ class Fp8LinearMethod(LinearMethodBase):
         quant_config: The quantization config.
     """
 
-    def __init__(self, quant_config: FP8Config):
+    def __init__(self, quant_config: Fp8Config):
         self.quant_config = quant_config
 
     def create_weights(
@@ -86,24 +90,24 @@ class Fp8LinearMethod(LinearMethodBase):
         layer.register_parameter("weight_scaling_factor", w_scale)
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        # Although the linear_method is propagated to all layers,
+        # Although the quant_method is propagated to all layers,
         # only linear layers invoke "create_weights". So we check
         # whether "weight_scaling_facor" is registered to determine
         # whether the layer is a linear layer that requires quantization.
         if not hasattr(layer, "weight_scaling_factor"):
             return
 
-        qweight, weight_scale = per_tensor_quantize(layer.weight)
+        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight)
         # torch._scaled_mm requires column-major in the second
         # input (weight), so we transpose the quantized weight.
         layer.weight = Parameter(qweight.t(), requires_grad=False)
         layer.weight_scaling_factor.data.copy_(weight_scale)
 
-    def apply_weights(self,
-                      layer: torch.nn.Module,
-                      x: torch.Tensor,
-                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        qinput, x_scale = per_tensor_quantize(x)
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        qinput, x_scale = ops.scaled_fp8_quant(x)
         output, _ = torch._scaled_mm(
             qinput,
             layer.weight,
@@ -113,27 +117,3 @@ class Fp8LinearMethod(LinearMethodBase):
             bias=bias,
         )
         return output
-
-
-def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
-    """Quantize a tensor using per-tensor static scaling factor.
-
-    Args:
-        tensor: The input tensor.
-    """
-    finfo = torch.finfo(torch.float8_e4m3fn)
-    # Calculate the scale as dtype max divided by absmax.
-    # Since .abs() creates a new tensor, we use aminmax to get
-    # the min and max first and then calculate the absmax.
-    min_val, max_val = tensor.aminmax()
-    amax = min_val.abs().max(max_val.abs())
-    scale = finfo.max / amax.clamp(min=1e-12)
-    # scale and clamp the tensor to bring it to
-    # the representative range of float8 data type
-    # (as default cast is unsaturated)
-    qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
-    # Return both float8 data and the inverse scale (as float),
-    # as both required as inputs to torch._scaled_mm
-    qweight = qweight.to(torch.float8_e4m3fn)
-    scale = scale.float().reciprocal()
-    return qweight, scale
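The removed per_tensor_quantize helper is superseded by the ops.scaled_fp8_quant custom op used above, but the arithmetic is the same: scale by dtype max over absmax, saturate to the FP8 range, and return the inverse scale that torch._scaled_mm expects alongside the quantized tensor. A standalone reference version of that computation (fp8_quantize_reference is an illustrative name, not a vLLM function), runnable on a recent PyTorch:

import torch


def fp8_quantize_reference(t: torch.Tensor):
    """Per-tensor FP8 (e4m3) quantization, mirroring the removed helper."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    # dtype max / absmax (the removed helper used aminmax to avoid the
    # extra allocation from .abs()).
    scale = finfo.max / t.abs().max().clamp(min=1e-12)
    # Saturate before casting, since the default cast is unsaturated.
    q = (t * scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
    # Return the inverse scale, which is what torch._scaled_mm consumes.
    return q, scale.float().reciprocal()


x = torch.randn(16, 32)
qx, inv_scale = fp8_quantize_reference(x)
# Dequantizing recovers the input up to e4m3 rounding error.
assert torch.allclose(qx.float() * inv_scale, x, atol=0.1, rtol=0.1)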