[Misc][Refactor] Generalize linear_method to be quant_method (#4373)
This commit is contained in:
@@ -2,8 +2,33 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
|
||||
class QuantizeMethodBase(ABC):
|
||||
"""Base class for different quantized methods."""
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(self, layer: torch.nn.Module, *weight_args,
|
||||
**extra_weight_attrs):
|
||||
"""Create weights for a layer.
|
||||
|
||||
The weights will be set as attributes of the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
|
||||
"""Apply the weights in layer to the input tensor.
|
||||
|
||||
Expects create_weights to have been called before on the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
def process_weights_after_loading(self, layer: nn.Module) -> None:
|
||||
"""Process the weight after loading.
|
||||
|
||||
This can be used for example, to transpose weights for computation.
|
||||
"""
|
||||
return
|
||||
|
||||
|
||||
class QuantizationConfig(ABC):
|
||||
@@ -51,8 +76,8 @@ class QuantizationConfig(ABC):
|
||||
"quantization config.")
|
||||
|
||||
@abstractmethod
|
||||
def get_linear_method(self) -> LinearMethodBase:
|
||||
"""Get the linear method to use for the quantized linear layer."""
|
||||
def get_quant_method(self, layer: torch.nn.Module) -> QuantizeMethodBase:
|
||||
"""Get the quantize method to use for the quantized layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
|
||||
Reference in New Issue
Block a user