[Kernel] Initial Activation Quantization Support (#4525)
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
@@ -56,7 +56,6 @@ class LinearMethodBase(QuantizeMethodBase):
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         """Apply the weights in layer to the input tensor.
-
         Expects create_weights to have been called before on the layer."""
         raise NotImplementedError
 
@@ -77,8 +76,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
                        output_partition_sizes: List[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
-        output_size_per_partition = sum(output_partition_sizes)
-        weight = Parameter(torch.empty(output_size_per_partition,
+        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                        input_size_per_partition,
                                        dtype=params_dtype),
                            requires_grad=False)
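For context, the `create_weights` hook visible in this hunk receives `output_partition_sizes` (one entry per fused sub-matrix, summed here for the unquantized weight) rather than a single pre-summed width, presumably so a quantization method can keep per-shard metadata such as scales for each fused sub-layer. A minimal standalone sketch of that split (plain PyTorch, not vLLM's actual classes; `ToyQuantLinearMethod`, `layer.scales`, and `layer.logical_widths` are invented names for illustration):

from typing import List, Optional

import torch
from torch.nn import Parameter


class ToyQuantLinearMethod:
    # Illustrative only; mirrors the create_weights/apply split shown above,
    # not vLLM's actual quantization classes.

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        # One fused weight covering every partition, plus one scale per shard
        # so per-sub-layer quantization metadata can survive the fusion.
        layer.weight = Parameter(torch.empty(sum(output_partition_sizes),
                                             input_size_per_partition,
                                             dtype=params_dtype),
                                 requires_grad=False)
        layer.scales = Parameter(torch.ones(len(output_partition_sizes)),
                                 requires_grad=False)
        layer.logical_widths = output_partition_sizes

    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # A real method would dequantize / requantize here; this sketch only
        # performs the plain GEMM.
        return torch.nn.functional.linear(x, layer.weight, bias)
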
@@ -149,15 +147,13 @@ class ReplicatedLinear(LinearBase):
         quant_config: Quantization configure.
     """
 
-    def __init__(
-        self,
-        input_size: int,
-        output_size: int,
-        bias: bool = True,
-        skip_bias_add: bool = False,
-        params_dtype: Optional[torch.dtype] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-    ):
+    def __init__(self,
+                 input_size: int,
+                 output_size: int,
+                 bias: bool = True,
+                 skip_bias_add: bool = False,
+                 params_dtype: Optional[torch.dtype] = None,
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config)
 
@@ -210,17 +206,15 @@ class ColumnParallelLinear(LinearBase):
         the list would be size 3.
     """
 
-    def __init__(
-        self,
-        input_size: int,
-        output_size: int,
-        bias: bool = True,
-        gather_output: bool = False,
-        skip_bias_add: bool = False,
-        params_dtype: Optional[torch.dtype] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        output_sizes: Optional[List[int]] = None,
-    ):
+    def __init__(self,
+                 input_size: int,
+                 output_size: int,
+                 bias: bool = True,
+                 gather_output: bool = False,
+                 skip_bias_add: bool = False,
+                 params_dtype: Optional[torch.dtype] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 output_sizes: Optional[List[int]] = None):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config)
 
@@ -228,18 +222,26 @@ class ColumnParallelLinear(LinearBase):
 
         # Divide the weight matrix along the last dimension.
         tp_size = get_tensor_model_parallel_world_size()
-        self.output_size_per_partition = divide(output_size, tp_size)
+        assert self.quant_method is not None
+        self.output_size_per_partition = divide(self.output_size, tp_size)
+        self.output_partition_sizes = [self.output_size_per_partition]
+        # If QKV or MergedColumn, use output size of each partition.
+        if hasattr(self, "output_sizes"):
+            self.output_partition_sizes = [
+                divide(output_size, tp_size)
+                for output_size in self.output_sizes
+            ]
+
         if output_sizes is None:
             output_sizes = [output_size]
-        # All the linear layer supports quant method.
-        assert self.quant_method is not None
-        self.quant_method.create_weights(self,
-                                         self.input_size,
-                                         [x // tp_size for x in output_sizes],
-                                         self.input_size,
-                                         self.output_size,
-                                         self.params_dtype,
-                                         weight_loader=self.weight_loader)
+        self.quant_method.create_weights(
+            layer=self,
+            input_size_per_partition=self.input_size,
+            output_partition_sizes=self.output_partition_sizes,
+            input_size=self.input_size,
+            output_size=self.output_size,
+            params_dtype=self.params_dtype,
+            weight_loader=self.weight_loader)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size_per_partition,
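As a concrete illustration of the bookkeeping above (hypothetical sizes, not taken from any particular model): a merged gate/up projection with `output_sizes = [11008, 11008]` on two tensor-parallel ranks ends up with `output_partition_sizes = [5504, 5504]`, while a plain column-parallel layer falls back to the single-entry list.

def divide(numerator: int, denominator: int) -> int:
    # Same contract implied by the asserts in this file: the split must be even.
    assert numerator % denominator == 0
    return numerator // denominator

tp_size = 2                      # assumed tensor-parallel world size
output_size = 22016              # merged gate+up width (hypothetical)
output_sizes = [11008, 11008]    # per-sub-layer widths (hypothetical)

output_size_per_partition = divide(output_size, tp_size)             # 11008
output_partition_sizes = [divide(s, tp_size) for s in output_sizes]  # [5504, 5504]
print(output_size_per_partition, output_partition_sizes)
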
@@ -317,22 +319,24 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         quant_config: Quantization configure.
     """
 
-    def __init__(
-        self,
-        input_size: int,
-        output_sizes: List[int],
-        bias: bool = True,
-        gather_output: bool = False,
-        skip_bias_add: bool = False,
-        params_dtype: Optional[torch.dtype] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-    ):
+    def __init__(self,
+                 input_size: int,
+                 output_sizes: List[int],
+                 bias: bool = True,
+                 gather_output: bool = False,
+                 skip_bias_add: bool = False,
+                 params_dtype: Optional[torch.dtype] = None,
+                 quant_config: Optional[QuantizationConfig] = None):
         self.output_sizes = output_sizes
         tp_size = get_tensor_model_parallel_world_size()
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
-        super().__init__(input_size, sum(output_sizes), bias, gather_output,
-                         skip_bias_add, params_dtype, quant_config,
-                         self.output_sizes)
+        super().__init__(input_size=input_size,
+                         output_size=sum(output_sizes),
+                         bias=bias,
+                         gather_output=gather_output,
+                         skip_bias_add=skip_bias_add,
+                         params_dtype=params_dtype,
+                         quant_config=quant_config)
 
     def weight_loader(self,
                       param: Parameter,
@@ -343,6 +347,26 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         output_dim = getattr(param, "output_dim", None)
         # Special case for AQLM codebooks.
         is_metadata = getattr(param, "is_metadata", False)
+
+        param_shard_splitter = getattr(param, "shard_splitter", None)
+
+        if output_dim is not None and param_shard_splitter is not None:
+            raise NotImplementedError(
+                "We do not currently support output_dim != None and "
+                "shard_splitter != None for a parameter. Please open an issue."
+            )
+        # If a parameter has defined a shard_splitter to be used for
+        # the weight, it should be applied before the weight is
+        # loaded/copied to the parameter. The shard_splitter applies
+        # logic by using the loaded_shard_id to ensure that the loaded
+        # param is loaded to the correct location
+        # within the parameter defined by the linear method.
+        if loaded_shard_id is None and param_shard_splitter is not None:
+            raise NotImplementedError(
+                "We do not currently support loaded_shard_id == None and "
+                "shard_splitter != None for a parameter. Please open an issue."
+            )
+
         # Special case for Fp8 scales.
         fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                            None)
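The `shard_splitter` itself is defined by the LinearMethod on the parameter (per the comment above) and is not shown in this diff; the loader only looks it up and calls it with `(param_data, loaded_weight, loaded_shard_id, logical_widths)`. A hypothetical splitter consistent with that call site might look like the sketch below; note that MergedColumnParallelLinear passes an integer shard index, while the QKV path passes "q"/"k"/"v", which a real splitter would have to map.

from typing import List

import torch


def toy_shard_splitter(param_data: torch.Tensor, loaded_weight: torch.Tensor,
                       shard_id: int, logical_widths: List[int]):
    # Hypothetical example: narrow the fused destination buffer to the slot
    # belonging to this shard, so the checkpoint tensor for one sub-layer
    # lands at the right offset. Matches the 4-argument call site above.
    offset = sum(logical_widths[:shard_id])
    size = logical_widths[shard_id]
    return param_data.narrow(0, offset, size), loaded_weight
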
@@ -403,6 +427,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             shard_size = loaded_weight.shape[0]
             shard_offset = loaded_shard_id * shard_size
             param_data = param_data.narrow(0, shard_offset, shard_size)
+
+        # If a param_shard_splitter is defined by the LinearMethod, use it.
+        elif param_shard_splitter is not None:
+            logical_widths = getattr(param, "logical_widths", None)
+            param_data, loaded_weight = param_shard_splitter(
+                param_data, loaded_weight, loaded_shard_id, logical_widths)
+
         # Special case for Fp8 scales.
         elif fp8_scales_shard_indexer is not None:
             param_data, loaded_weight = fp8_scales_shard_indexer(
@@ -415,6 +446,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 "Loading a weight without `output_dim` attribute in "
                 "MergedColumnParallelLinear, assume the weight is "
                 "the same for all partitions.")
+
+        if fp8_scales_shard_indexer is None:
+            if len(param_data.shape) == 0:
+                param_data = param_data.reshape(1)
+
+            if len(loaded_weight.shape) == 0:
+                loaded_weight = loaded_weight.reshape(1)
+
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
@@ -443,17 +482,15 @@ class QKVParallelLinear(ColumnParallelLinear):
         quant_config: Quantization configure.
     """
 
-    def __init__(
-        self,
-        hidden_size: int,
-        head_size: int,
-        total_num_heads: int,
-        total_num_kv_heads: Optional[int] = None,
-        bias: bool = True,
-        skip_bias_add: bool = False,
-        params_dtype: Optional[torch.dtype] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-    ):
+    def __init__(self,
+                 hidden_size: int,
+                 head_size: int,
+                 total_num_heads: int,
+                 total_num_kv_heads: Optional[int] = None,
+                 bias: bool = True,
+                 skip_bias_add: bool = False,
+                 params_dtype: Optional[torch.dtype] = None,
+                 quant_config: Optional[QuantizationConfig] = None):
         self.hidden_size = hidden_size
         self.head_size = head_size
         self.total_num_heads = total_num_heads
@@ -473,14 +510,19 @@ class QKVParallelLinear(ColumnParallelLinear):
         input_size = self.hidden_size
         output_size = (self.num_heads +
                        2 * self.num_kv_heads) * tp_size * self.head_size
-        output_sizes = [
-            self.num_heads * tp_size * self.head_size,
-            self.num_kv_heads * tp_size * self.head_size,
-            self.num_kv_heads * tp_size * self.head_size
+        self.output_sizes = [
+            self.num_heads * self.head_size * tp_size,  # q_proj
+            self.num_kv_heads * self.head_size * tp_size,  # k_proj
+            self.num_kv_heads * self.head_size * tp_size,  # v_proj
         ]
 
-        super().__init__(input_size, output_size, bias, False, skip_bias_add,
-                         params_dtype, quant_config, output_sizes)
+        super().__init__(input_size=input_size,
+                         output_size=output_size,
+                         bias=bias,
+                         gather_output=False,
+                         skip_bias_add=skip_bias_add,
+                         params_dtype=params_dtype,
+                         quant_config=quant_config)
 
     def weight_loader(self,
                       param: Parameter,
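To make the q/k/v arithmetic above concrete (hypothetical shapes: a GQA model with 32 query heads and 8 KV heads of size 128 on two tensor-parallel ranks), the three entries of `self.output_sizes` are the full, unsharded widths of q_proj, k_proj, and v_proj, and they sum to `output_size`:

tp_size = 2                  # assumed tensor-parallel world size
head_size = 128              # hypothetical
total_num_heads = 32         # hypothetical query heads
total_num_kv_heads = 8       # hypothetical KV heads (GQA)

# Per-rank head counts, mirroring num_heads / num_kv_heads in the class
# (simple even division; the replication path for tiny KV counts is ignored).
num_heads = total_num_heads // tp_size          # 16
num_kv_heads = total_num_kv_heads // tp_size    # 4

output_size = (num_heads + 2 * num_kv_heads) * tp_size * head_size   # 6144
output_sizes = [
    num_heads * head_size * tp_size,     # q_proj: 4096
    num_kv_heads * head_size * tp_size,  # k_proj: 1024
    num_kv_heads * head_size * tp_size,  # v_proj: 1024
]
print(output_size, output_sizes)         # 6144 [4096, 1024, 1024]
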
@@ -490,6 +532,26 @@ class QKVParallelLinear(ColumnParallelLinear):
         output_dim = getattr(param, "output_dim", None)
         # Special case for AQLM codebooks.
         is_metadata = getattr(param, "is_metadata", False)
+
+        param_shard_splitter = getattr(param, "shard_splitter", None)
+
+        if output_dim is not None and param_shard_splitter is not None:
+            raise NotImplementedError(
+                "We do not currently support output_dim != None and "
+                "shard_splitter != None for a parameter. Please open an issue."
+            )
+        # If a parameter has defined a shard_splitter to be used for
+        # the weight, it should be applied before the weight is
+        # loaded/copied to the parameter. The shard_splitter applies
+        # logic by using the loaded_shard_id to ensure that the loaded
+        # param is loaded to the correct location
+        # within the parameter defined by the linear method.
+        if loaded_shard_id is None and param_shard_splitter is not None:
+            raise NotImplementedError(
+                "We do not currently support loaded_shard_id == None and "
+                "shard_splitter != None for a parameter. Please open an issue."
+            )
+
         # Special case for Fp8 scales.
         fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                            None)
@@ -528,6 +590,8 @@ class QKVParallelLinear(ColumnParallelLinear):
 
         tp_rank = get_tensor_model_parallel_rank()
         assert loaded_shard_id in ["q", "k", "v"]
+
+        # If output dim is defined, use the default loading process.
         if output_dim is not None:
             if loaded_shard_id == "q":
                 shard_offset = 0
@@ -567,6 +631,12 @@ class QKVParallelLinear(ColumnParallelLinear):
             shard_index = ["q", "k", "v"].index(loaded_shard_id)
             param_data = param_data.narrow(0, shard_index * shard_size,
                                            shard_size)
+        # If a param_shard_splitter is defined by the LinearMethod, use it.
+        elif param_shard_splitter is not None:
+            logical_widths = getattr(param, "logical_widths", None)
+            param_data, loaded_weight = param_shard_splitter(
+                param_data, loaded_weight, loaded_shard_id, logical_widths)
+
         # Special case for Fp8 scales.
         elif fp8_scales_shard_indexer is not None:
             param_data, loaded_weight = fp8_scales_shard_indexer(
@@ -578,6 +648,13 @@ class QKVParallelLinear(ColumnParallelLinear):
                 "Loading a weight without `output_dim` attribute in "
                 "QKVParallelLinear, assume the weight is the same "
                 "for all partitions.")
+
+        if len(param_data.shape) == 0:
+            param_data = param_data.reshape(1)
+
+        if len(loaded_weight.shape) == 0:
+            loaded_weight = loaded_weight.reshape(1)
+
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
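The reshape guards added in these weight loaders exist because per-tensor quantization scales can be stored as 0-dim tensors in a checkpoint while the destination slot is a 1-element tensor (or vice versa); promoting the scalar to shape (1,) lets the final shape assert and `copy_` succeed. A tiny standalone illustration (plain PyTorch, not tied to any particular checkpoint format):

import torch

param_data = torch.zeros(1)          # scale slot created as a 1-element tensor
loaded_weight = torch.tensor(0.02)   # checkpoint stores a 0-dim scalar scale

if len(loaded_weight.shape) == 0:
    loaded_weight = loaded_weight.reshape(1)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
print(param_data)                    # tensor([0.0200])
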
@@ -608,17 +685,15 @@ class RowParallelLinear(LinearBase):
         quant_config: Quantization configure.
     """
 
-    def __init__(
-        self,
-        input_size: int,
-        output_size: int,
-        bias: bool = True,
-        input_is_parallel: bool = True,
-        skip_bias_add: bool = False,
-        params_dtype: Optional[torch.dtype] = None,
-        reduce_results: bool = True,
-        quant_config: Optional[QuantizationConfig] = None,
-    ):
+    def __init__(self,
+                 input_size: int,
+                 output_size: int,
+                 bias: bool = True,
+                 input_is_parallel: bool = True,
+                 skip_bias_add: bool = False,
+                 params_dtype: Optional[torch.dtype] = None,
+                 reduce_results: bool = True,
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config)
 
@@ -628,16 +703,15 @@ class RowParallelLinear(LinearBase):
         # Divide the weight matrix along the last dimension.
         self.tp_size = get_tensor_model_parallel_world_size()
         self.input_size_per_partition = divide(input_size, self.tp_size)
-        # All the linear layer supports quant method.
         assert self.quant_method is not None
-        self.quant_method.create_weights(self,
-                                         self.input_size_per_partition,
-                                         [self.output_size],
-                                         self.input_size,
-                                         self.output_size,
-                                         self.params_dtype,
-                                         weight_loader=self.weight_loader)
-
+        self.quant_method.create_weights(
+            layer=self,
+            input_size_per_partition=self.input_size_per_partition,
+            output_partition_sizes=[self.output_size],
+            input_size=self.input_size,
+            output_size=self.output_size,
+            params_dtype=self.params_dtype,
+            weight_loader=self.weight_loader)
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
                              "results can lead to incorrect results")
@@ -665,12 +739,16 @@ class RowParallelLinear(LinearBase):
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                  shard_size)
+
         # Special case for Fp8 scales.
         elif fp8_scales_shard_indexer is not None:
             param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
                                                                   loaded_weight,
                                                                   shard_id=0)
 
+        if fp8_scales_shard_indexer is None and len(loaded_weight.shape) == 0:
+            loaded_weight = loaded_weight.reshape(1)
+
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 