[Misc][Refactor] Generalize linear_method to be quant_method (#4373)

This commit is contained in:
Cody Yu
2024-04-26 13:41:14 -07:00
committed by GitHub
parent 603ad84815
commit a62aaf1df5
45 changed files with 759 additions and 713 deletions

View File

@@ -2,8 +2,33 @@ from abc import ABC, abstractmethod
from typing import Any, Dict, List
import torch
from torch import nn
from vllm.model_executor.layers.linear import LinearMethodBase
class QuantizeMethodBase(ABC):
    """Abstract interface shared by all quantization methods.

    A concrete subclass is responsible for two things: materializing the
    (possibly quantized) weights on a layer, and applying those weights
    to an input tensor during the forward pass.
    """

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module, *weight_args,
                       **extra_weight_attrs):
        """Create weights for a layer.

        The weights will be set as attributes of the layer.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.
        """
        raise NotImplementedError

    def process_weights_after_loading(self, layer: nn.Module) -> None:
        """Process the weight after loading.

        This can be used for example, to transpose weights for computation.
        Default is a no-op; subclasses override only when post-processing
        is required.
        """
        return
class QuantizationConfig(ABC):
@@ -51,8 +76,8 @@ class QuantizationConfig(ABC):
"quantization config.")
@abstractmethod
def get_linear_method(self) -> LinearMethodBase:
"""Get the linear method to use for the quantized linear layer."""
def get_quant_method(self, layer: torch.nn.Module) -> QuantizeMethodBase:
"""Get the quantize method to use for the quantized layer."""
raise NotImplementedError
@abstractmethod