[Misc][Refactor] Generalize linear_method to be quant_method (#4373)
@@ -1,16 +1,17 @@
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional
 
 import torch
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 
-from vllm.model_executor.layers.linear import (LinearMethodBase,
-                                               set_weight_attrs)
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.utils import set_weight_attrs
 
 
-class FP8Config(QuantizationConfig):
+class Fp8Config(QuantizationConfig):
     """Config class for FP8."""
 
     @classmethod
@@ -33,11 +34,14 @@ class FP8Config(QuantizationConfig):
         return []
 
     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "FP8Config":
+    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
         return cls()
 
-    def get_linear_method(self) -> "Fp8LinearMethod":
-        return Fp8LinearMethod(self)
+    def get_quant_method(
+            self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
+        if isinstance(layer, LinearBase):
+            return Fp8LinearMethod(self)
+        return None
 
     def get_scaled_act_names(self) -> List[str]:
         return []
@@ -57,7 +61,7 @@ class Fp8LinearMethod(LinearMethodBase):
         quant_config: The quantization config.
     """
 
-    def __init__(self, quant_config: FP8Config):
+    def __init__(self, quant_config: Fp8Config):
         self.quant_config = quant_config
 
     def create_weights(
@@ -86,24 +90,24 @@ class Fp8LinearMethod(LinearMethodBase):
         layer.register_parameter("weight_scaling_factor", w_scale)
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        # Although the linear_method is propagated to all layers,
+        # Although the quant_method is propagated to all layers,
         # only linear layers invoke "create_weights". So we check
         # whether "weight_scaling_factor" is registered to determine
         # whether the layer is a linear layer that requires quantization.
         if not hasattr(layer, "weight_scaling_factor"):
             return
 
-        qweight, weight_scale = per_tensor_quantize(layer.weight)
+        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight)
         # torch._scaled_mm requires column-major in the second
         # input (weight), so we transpose the quantized weight.
         layer.weight = Parameter(qweight.t(), requires_grad=False)
         layer.weight_scaling_factor.data.copy_(weight_scale)
 
-    def apply_weights(self,
-                      layer: torch.nn.Module,
-                      x: torch.Tensor,
-                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        qinput, x_scale = per_tensor_quantize(x)
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        qinput, x_scale = ops.scaled_fp8_quant(x)
         output, _ = torch._scaled_mm(
             qinput,
             layer.weight,
@@ -113,27 +117,3 @@ class Fp8LinearMethod(LinearMethodBase):
             bias=bias,
         )
         return output
-
-
-def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
-    """Quantize a tensor using per-tensor static scaling factor.
-
-    Args:
-        tensor: The input tensor.
-    """
-    finfo = torch.finfo(torch.float8_e4m3fn)
-    # Calculate the scale as dtype max divided by absmax.
-    # Since .abs() creates a new tensor, we use aminmax to get
-    # the min and max first and then calculate the absmax.
-    min_val, max_val = tensor.aminmax()
-    amax = min_val.abs().max(max_val.abs())
-    scale = finfo.max / amax.clamp(min=1e-12)
-    # scale and clamp the tensor to bring it to
-    # the representative range of float8 data type
-    # (as default cast is unsaturated)
-    qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
-    # Return both float8 data and the inverse scale (as float),
-    # as both required as inputs to torch._scaled_mm
-    qweight = qweight.to(torch.float8_e4m3fn)
-    scale = scale.float().reciprocal()
-    return qweight, scale
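
The core API change above is that a QuantizationConfig no longer exposes a linear-only get_linear_method(); it exposes get_quant_method(layer), which dispatches on the layer type and returns None for layers the config does not quantize. The sketch below illustrates that dispatch pattern in isolation; it is a minimal stand-alone example, not vLLM code, and all Demo* class names are invented for illustration.

# Minimal, self-contained sketch of the get_quant_method dispatch
# pattern introduced by this commit. The Demo* names are hypothetical
# stand-ins, not vLLM classes.
from typing import Optional

import torch


class DemoQuantizeMethodBase:
    """Per-layer quantize method (mirrors QuantizeMethodBase)."""


class DemoLinearMethod(DemoQuantizeMethodBase):
    """Quantized GEMM path for linear layers (mirrors Fp8LinearMethod)."""

    def __init__(self, config: "DemoQuantConfig"):
        self.config = config


class DemoQuantConfig:
    """Mirrors the new QuantizationConfig.get_quant_method contract."""

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional[DemoQuantizeMethodBase]:
        # Dispatch on the layer type: only linear layers get a quantize
        # method; any other layer returns None and stays unquantized.
        if isinstance(layer, torch.nn.Linear):
            return DemoLinearMethod(self)
        return None


config = DemoQuantConfig()
assert isinstance(config.get_quant_method(torch.nn.Linear(16, 16)),
                  DemoLinearMethod)
assert config.get_quant_method(torch.nn.LayerNorm(16)) is None

Compared with the old get_linear_method(), passing the layer into the call lets one config object serve heterogeneous layer types or skip some entirely, which is exactly what allows the FP8 path here to return Fp8LinearMethod only for LinearBase layers.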