[Misc][Refactor] Generalize linear_method to be quant_method (#4373)
@@ -1,9 +1,8 @@
-from abc import ABC, abstractmethod
+from abc import abstractmethod
 from typing import List, Optional
 
 import torch
 import torch.nn.functional as F
-from torch import nn
 from torch.nn.parameter import Parameter
 
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
@@ -12,6 +11,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
 
 logger = init_logger(__name__)
@@ -25,7 +26,7 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
     return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
 
 
-class LinearMethodBase(ABC):
+class LinearMethodBase(QuantizeMethodBase):
     """Base class for different (maybe quantized) linear methods."""
 
     @abstractmethod
@@ -50,22 +51,15 @@ class LinearMethodBase(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def apply_weights(self,
-                      layer: torch.nn.Module,
-                      x: torch.Tensor,
-                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         """Apply the weights in layer to the input tensor.
 
         Expects create_weights to have been called before on the layer."""
         raise NotImplementedError
-
-    def process_weights_after_loading(self, layer: nn.Module) -> None:
-        """Process the weight after loading.
-
-        This can be used for example, to transpose weights for computation.
-        """
-        return
 
 
 class UnquantizedLinearMethod(LinearMethodBase):
     """Linear method without quantization.
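This hunk carries the core interface change: `LinearMethodBase` now derives from the generic `QuantizeMethodBase`, its abstract `apply_weights` hook is renamed to `apply`, and the `process_weights_after_loading` default is dropped from this class (the new shared base in `base_config` is the natural home for it). As a rough, hedged sketch of what a method plugging into this interface looks like; the argument names below are paraphrased from the `create_weights` call sites later in this diff, not copied from the real signature:

```python
from typing import List, Optional

import torch
import torch.nn.functional as F


class SimpleFp16Method:
    """Illustrative only: a LinearMethodBase-style method using the new names.

    Parameter names are assumptions based on the call sites in this diff
    (per-partition input size, per-partition output sizes, full sizes, dtype).
    """

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int],
                       input_size: int, output_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        # A plain dense weight; a quantized method would allocate packed
        # qweights, scales and zero points here instead.
        weight = torch.nn.Parameter(torch.empty(sum(output_partition_sizes),
                                                input_size_per_partition,
                                                dtype=params_dtype),
                                    requires_grad=False)
        layer.register_parameter("weight", weight)

    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Renamed from apply_weights(): run the (de)quantized matmul.
        return F.linear(x, layer.weight, bias)
```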
@@ -92,10 +86,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, extra_weight_attrs)
 
-    def apply_weights(self,
-                      layer: torch.nn.Module,
-                      x: torch.Tensor,
-                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         weight = layer.weight
         if self.separate_bias_add:
             if bias is not None:
@@ -104,8 +98,8 @@ class UnquantizedLinearMethod(LinearMethodBase):
         return F.linear(x, weight, bias)
 
 
-class ReplicatedLinear(torch.nn.Module):
-    """Replicated linear layer.
+class LinearBase(torch.nn.Module):
+    """Base linear layer.
 
     Args:
         input_size: input dimension of the linear layer.
@@ -113,17 +107,16 @@ class ReplicatedLinear(torch.nn.Module):
         bias: If true, add bias.
         skip_bias_add: If true, skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
-        linear_method: (Maybe quantized) linear method.
+        quant_config: Quantization configure.
     """
 
     def __init__(
         self,
         input_size: int,
         output_size: int,
-        bias: bool = True,
         skip_bias_add: bool = False,
         params_dtype: Optional[torch.dtype] = None,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
 
@@ -134,12 +127,43 @@ class ReplicatedLinear(torch.nn.Module):
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
-        if linear_method is None:
-            linear_method = UnquantizedLinearMethod()
-        self.linear_method = linear_method
-        self.linear_method.create_weights(self, self.input_size,
-                                          [self.output_size], self.input_size,
-                                          self.output_size, self.params_dtype)
+        if quant_config is None:
+            self.quant_method = UnquantizedLinearMethod()
+        else:
+            self.quant_method = quant_config.get_quant_method(self)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
+
+
+class ReplicatedLinear(LinearBase):
+    """Replicated linear layer.
+
+    Args:
+        input_size: input dimension of the linear layer.
+        output_size: output dimension of the linear layer.
+        bias: If true, add bias.
+        skip_bias_add: If true, skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configure.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
+                         quant_config)
+
+        self.quant_method.create_weights(self, self.input_size,
+                                         [self.output_size], self.input_size,
+                                         self.output_size, self.params_dtype)
+
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size, dtype=self.params_dtype))
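The new `LinearBase` centralizes dispatch that every layer previously repeated by hand: with no `quant_config` it installs `UnquantizedLinearMethod`, otherwise it asks the config for a per-layer method via `quant_config.get_quant_method(self)`. A minimal, hedged sketch of the config side of that contract; only the `get_quant_method` hook itself is visible in this diff, and the class names and bodies below are illustrative stubs, not vLLM classes:

```python
from typing import Optional

import torch


class StubQuantMethod:
    """Illustrative stand-in for a QuantizeMethodBase implementation."""

    def create_weights(self, layer: torch.nn.Module, *sizes, **extra_attrs):
        raise NotImplementedError  # real methods register qweights/scales here

    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        raise NotImplementedError  # real methods run the quantized matmul


class StubQuantConfig:
    """Illustrative stand-in for a QuantizationConfig subclass."""

    def get_quant_method(self, layer: torch.nn.Module) -> Optional[StubQuantMethod]:
        # LinearBase.__init__ passes the layer itself, so a config can make
        # per-layer choices; whatever is returned becomes layer.quant_method.
        return StubQuantMethod()
```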
@@ -149,12 +173,12 @@ class ReplicatedLinear(torch.nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         bias = self.bias if not self.skip_bias_add else None
-        output = self.linear_method.apply_weights(self, x, bias)
+        output = self.quant_method.apply(self, x, bias)
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
 
 
-class ColumnParallelLinear(torch.nn.Module):
+class ColumnParallelLinear(LinearBase):
     """Linear layer with column parallelism.
 
     The linear layer is defined as Y = XA + b. A is parallelized along
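From a caller's perspective the layer behaves as before: `forward` still returns the output together with the bias when `skip_bias_add` defers it, and only the internal call changes to `self.quant_method.apply`. A hedged usage sketch, assuming a vLLM checkout that includes this change and default (unquantized) float32 weights:

```python
import torch

# Assumes vLLM is installed at a revision containing this refactor.
from vllm.model_executor.layers.linear import ReplicatedLinear

layer = ReplicatedLinear(input_size=1024,
                         output_size=4096,
                         bias=True,
                         quant_config=None)  # falls back to UnquantizedLinearMethod

x = torch.randn(2, 1024)
out, out_bias = layer(x)        # forward() returns (output, output_bias)
assert out.shape == (2, 4096)
assert out_bias is None         # bias already added since skip_bias_add=False
```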
@@ -171,7 +195,7 @@ class ColumnParallelLinear(torch.nn.Module):
                        bias can be fused with other element-wise operations. we
                        skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
-        linear_method: (Maybe quantized) linear method.
+        quant_config: Quantization configure.
         output_sizes: list of output sizes packed into one output, like for QKV
                        the list would be size 3.
     """
@@ -184,34 +208,26 @@ class ColumnParallelLinear(torch.nn.Module):
         gather_output: bool = False,
         skip_bias_add: bool = False,
         params_dtype: Optional[torch.dtype] = None,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         output_sizes: Optional[List[int]] = None,
     ):
-        super().__init__()
+        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
+                         quant_config)
 
-        # Keep input parameters
-        self.input_size = input_size
-        self.output_size = output_size
         self.gather_output = gather_output
+
         # Divide the weight matrix along the last dimension.
         tp_size = get_tensor_model_parallel_world_size()
         self.output_size_per_partition = divide(output_size, tp_size)
-        self.skip_bias_add = skip_bias_add
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
-        if linear_method is None:
-            linear_method = UnquantizedLinearMethod()
         if output_sizes is None:
             output_sizes = [output_size]
-        self.linear_method = linear_method
-        self.linear_method.create_weights(self,
-                                          self.input_size,
-                                          [x // tp_size for x in output_sizes],
-                                          self.input_size,
-                                          self.output_size,
-                                          self.params_dtype,
-                                          weight_loader=self.weight_loader)
+        self.quant_method.create_weights(self,
+                                         self.input_size,
+                                         [x // tp_size for x in output_sizes],
+                                         self.input_size,
+                                         self.output_size,
+                                         self.params_dtype,
+                                         weight_loader=self.weight_loader)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size_per_partition,
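Note how the weight shapes handed to `create_weights` are already sharded along the output dimension: every packed output in `output_sizes` is divided by the tensor-parallel world size. With illustrative numbers that are not taken from the diff:

```python
# Illustrative only: a merged projection packing two 11008-wide outputs,
# sharded over 4 tensor-parallel ranks.
output_sizes = [11008, 11008]
tp_size = 4
output_partition_sizes = [x // tp_size for x in output_sizes]
print(output_partition_sizes)  # [2752, 2752] -> the slice each rank owns
# create_weights() also receives the full input_size and output_size so a
# quantized method can reason about the unsharded layer if it needs to.
```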
@@ -239,7 +255,7 @@ class ColumnParallelLinear(torch.nn.Module):
         bias = self.bias if not self.skip_bias_add else None
 
         # Matrix multiply.
-        output_parallel = self.linear_method.apply_weights(self, input_, bias)
+        output_parallel = self.quant_method.apply(self, input_, bias)
         if self.gather_output:
             # All-gather across the partitions.
             output = tensor_model_parallel_all_gather(output_parallel)
@@ -267,7 +283,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                        bias can be fused with other element-wise operations. we
                        skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
-        linear_method: (Maybe quantized) linear method.
+        quant_config: Quantization configure.
     """
 
     def __init__(
@@ -278,13 +294,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         gather_output: bool = False,
         skip_bias_add: bool = False,
         params_dtype: Optional[torch.dtype] = None,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         self.output_sizes = output_sizes
         tp_size = get_tensor_model_parallel_world_size()
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
         super().__init__(input_size, sum(output_sizes), bias, gather_output,
-                         skip_bias_add, params_dtype, linear_method,
+                         skip_bias_add, params_dtype, quant_config,
                          self.output_sizes)
 
     def weight_loader(self,
@@ -384,7 +400,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                        bias can be fused with other element-wise operations. we
                        skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
-        linear_method: (Maybe quantized) linear method.
+        quant_config: Quantization configure.
     """
 
     def __init__(
@@ -396,7 +412,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         bias: bool = True,
         skip_bias_add: bool = False,
         params_dtype: Optional[torch.dtype] = None,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         self.hidden_size = hidden_size
         self.head_size = head_size
@@ -424,7 +440,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         ]
 
         super().__init__(input_size, output_size, bias, False, skip_bias_add,
-                         params_dtype, linear_method, output_sizes)
+                         params_dtype, quant_config, output_sizes)
 
     def weight_loader(self,
                       param: Parameter,
@@ -517,7 +533,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         param_data.copy_(loaded_weight)
 
 
-class RowParallelLinear(torch.nn.Module):
+class RowParallelLinear(LinearBase):
     """Linear layer with row parallelism.
 
     The linear layer is defined as Y = XA + b. A is parallelized along
@@ -540,7 +556,7 @@ class RowParallelLinear(torch.nn.Module):
                        bias can be fused with other element-wise operations.
                        We skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
-        linear_method: (Maybe quantized) linear method.
+        quant_config: Quantization configure.
     """
 
     def __init__(
@@ -552,32 +568,24 @@ class RowParallelLinear(torch.nn.Module):
         skip_bias_add: bool = False,
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = True,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
-        super().__init__()
-        # Keep input parameters
-        self.input_size = input_size
-        self.output_size = output_size
+        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
+                         quant_config)
+
         self.input_is_parallel = input_is_parallel
         self.reduce_results = reduce_results
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
 
         # Divide the weight matrix along the last dimension.
         self.tp_size = get_tensor_model_parallel_world_size()
         self.input_size_per_partition = divide(input_size, self.tp_size)
-        self.skip_bias_add = skip_bias_add
-        if linear_method is None:
-            linear_method = UnquantizedLinearMethod()
-        self.linear_method = linear_method
-        self.linear_method.create_weights(self,
-                                          self.input_size_per_partition,
-                                          [self.output_size],
-                                          self.input_size,
-                                          self.output_size,
-                                          self.params_dtype,
-                                          weight_loader=self.weight_loader)
+        self.quant_method.create_weights(self,
+                                         self.input_size_per_partition,
+                                         [self.output_size],
+                                         self.input_size,
+                                         self.output_size,
+                                         self.params_dtype,
+                                         weight_loader=self.weight_loader)
 
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
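`RowParallelLinear` shards the opposite dimension: its quant method sees a per-rank `input_size_per_partition = divide(input_size, self.tp_size)` together with the full output size, and the partial results are later combined with an all-reduce in `forward`. With illustrative numbers that are not taken from the diff:

```python
# Illustrative only: a down-projection (11008 -> 4096) over 4 ranks.
input_size, output_size, tp_size = 11008, 4096, 4
input_size_per_partition = input_size // tp_size
print(input_size_per_partition)  # 2752 input columns handled per rank
# Each rank computes a partial (batch, 4096) result from its slice of x;
# tensor_model_parallel_all_reduce then sums the partials into the output.
```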
@@ -616,8 +624,7 @@ class RowParallelLinear(torch.nn.Module):
             input_parallel = splitted_input[tp_rank].contiguous()
 
         # Matrix multiply.
-        output_parallel = self.linear_method.apply_weights(
-            self, input_parallel)
+        output_parallel = self.quant_method.apply(self, input_parallel)
         if self.reduce_results and self.tp_size > 1:
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else: