Add GPTQ support (#916)
This commit is contained in:
@@ -142,7 +142,7 @@ class ModelConfig:
|
||||
self.tokenizer_mode = tokenizer_mode
|
||||
|
||||
def _verify_quantization(self) -> None:
|
||||
supported_quantization = ["awq", "squeezellm"]
|
||||
supported_quantization = ["awq", "gptq", "squeezellm"]
|
||||
rocm_not_supported_quantization = ["awq"]
|
||||
if self.quantization is not None:
|
||||
self.quantization = self.quantization.lower()
|
||||
|
||||
@@ -179,7 +179,7 @@ class EngineArgs:
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
type=str,
|
||||
choices=['awq', 'squeezellm', None],
|
||||
choices=['awq', 'gptq', 'squeezellm', None],
|
||||
default=None,
|
||||
help='Method used to quantize the weights')
|
||||
return parser
|
||||
|
||||
@@ -38,8 +38,9 @@ class LLM:
|
||||
However, if the `torch_dtype` in the config is `float32`, we will
|
||||
use `float16` instead.
|
||||
quantization: The method used to quantize the model weights. Currently,
|
||||
we support "awq". If None, we assume the model weights are not
|
||||
quantized and use `dtype` to determine the data type of the weights.
|
||||
we support "awq", "gptq" and "squeezellm". If None, we assume the
|
||||
model weights are not quantized and use `dtype` to determine the
|
||||
data type of the weights.
|
||||
revision: The specific model version to use. It can be a branch name,
|
||||
a tag name, or a commit id.
|
||||
tokenizer_revision: The specific tokenizer version to use. It can be a
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -21,8 +21,10 @@ class LinearMethodBase(ABC):
|
||||
"""Base class for different (maybe quantized) linear methods."""
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
|
||||
def create_weights(self, input_size_per_partition: int,
|
||||
output_size_per_partition: int, input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
"""Create weights for a linear layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -46,10 +48,12 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
def __init__(self, separate_bias_add: bool = False):
|
||||
self.separate_bias_add = separate_bias_add
|
||||
|
||||
def create_weights(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
|
||||
weight = Parameter(torch.empty(output_size,
|
||||
input_size,
|
||||
def create_weights(self, input_size_per_partition: int,
|
||||
output_size_per_partition: int, input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
weight = Parameter(torch.empty(output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
@@ -102,9 +106,11 @@ class ReplicatedLinear(torch.nn.Module):
|
||||
linear_method = UnquantizedLinearMethod()
|
||||
self.linear_method = linear_method
|
||||
self.linear_weights = self.linear_method.create_weights(
|
||||
self.input_size, self.output_size, self.params_dtype)
|
||||
self.input_size, self.output_size, self.input_size,
|
||||
self.output_size, self.params_dtype)
|
||||
for name, weight in self.linear_weights.items():
|
||||
self.register_parameter(name, weight)
|
||||
if isinstance(weight, torch.Tensor):
|
||||
self.register_parameter(name, weight)
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size,
|
||||
@@ -168,10 +174,12 @@ class ColumnParallelLinear(torch.nn.Module):
|
||||
linear_method = UnquantizedLinearMethod()
|
||||
self.linear_method = linear_method
|
||||
self.linear_weights = self.linear_method.create_weights(
|
||||
self.input_size, self.output_size_per_partition, self.params_dtype)
|
||||
self.input_size, self.output_size_per_partition, self.input_size,
|
||||
self.output_size, self.params_dtype)
|
||||
for name, weight in self.linear_weights.items():
|
||||
self.register_parameter(name, weight)
|
||||
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
|
||||
if isinstance(weight, torch.Tensor):
|
||||
self.register_parameter(name, weight)
|
||||
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size_per_partition,
|
||||
@@ -295,10 +303,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
else:
|
||||
logger.warning(
|
||||
"Loading a weight without `output_dim` attribute in "
|
||||
"MergedColumnParallelLinear, assume the weight is "
|
||||
"the same for all partitions.")
|
||||
ignore_warning = getattr(param, "ignore_warning", False)
|
||||
if not ignore_warning:
|
||||
logger.warning(
|
||||
"Loading a weight without `output_dim` attribute in "
|
||||
"MergedColumnParallelLinear, assume the weight is "
|
||||
"the same for all partitions.")
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
@@ -418,10 +428,12 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
else:
|
||||
logger.warning(
|
||||
"Loading a weight without `output_dim` attribute in "
|
||||
"QKVParallelLinear, assume the weight is the same "
|
||||
"for all partitions.")
|
||||
ignore_warning = getattr(param, "ignore_warning", False)
|
||||
if not ignore_warning:
|
||||
logger.warning(
|
||||
"Loading a weight without `output_dim` attribute in "
|
||||
"QKVParallelLinear, assume the weight is the same "
|
||||
"for all partitions.")
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
@@ -481,10 +493,12 @@ class RowParallelLinear(torch.nn.Module):
|
||||
linear_method = UnquantizedLinearMethod()
|
||||
self.linear_method = linear_method
|
||||
self.linear_weights = self.linear_method.create_weights(
|
||||
self.input_size_per_partition, self.output_size, self.params_dtype)
|
||||
self.input_size_per_partition, self.output_size, self.input_size,
|
||||
self.output_size, self.params_dtype)
|
||||
for name, weight in self.linear_weights.items():
|
||||
self.register_parameter(name, weight)
|
||||
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
|
||||
if isinstance(weight, torch.Tensor):
|
||||
self.register_parameter(name, weight)
|
||||
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
|
||||
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError("When not reduce the results, adding bias to the "
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from typing import Type
|
||||
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
||||
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
|
||||
|
||||
_QUANTIZATION_CONFIG_REGISTRY = {
|
||||
"awq": AWQConfig,
|
||||
"gptq": GPTQConfig,
|
||||
"squeezellm": SqueezeLLMConfig,
|
||||
}
|
||||
|
||||
|
||||
@@ -77,14 +77,16 @@ class AWQLinearMethod(LinearMethodBase):
|
||||
def __init__(self, quant_config: AWQConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
|
||||
if input_size % self.quant_config.group_size != 0:
|
||||
def create_weights(self, input_size_per_partition: int,
|
||||
output_size_per_partition: int, input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
if input_size_per_partition % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
if output_size % self.quant_config.pack_factor != 0:
|
||||
if output_size_per_partition % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
@@ -92,8 +94,8 @@ class AWQLinearMethod(LinearMethodBase):
|
||||
|
||||
qweight = Parameter(
|
||||
torch.empty(
|
||||
input_size,
|
||||
output_size // self.quant_config.pack_factor,
|
||||
input_size_per_partition,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
@@ -108,8 +110,8 @@ class AWQLinearMethod(LinearMethodBase):
|
||||
})
|
||||
qzeros = Parameter(
|
||||
torch.empty(
|
||||
input_size // self.quant_config.group_size,
|
||||
output_size // self.quant_config.pack_factor,
|
||||
input_size_per_partition // self.quant_config.group_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
@@ -124,8 +126,8 @@ class AWQLinearMethod(LinearMethodBase):
|
||||
})
|
||||
scales = Parameter(
|
||||
torch.empty(
|
||||
input_size // self.quant_config.group_size,
|
||||
output_size,
|
||||
input_size_per_partition // self.quant_config.group_size,
|
||||
output_size_per_partition,
|
||||
device="cuda",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
@@ -142,7 +144,7 @@ class AWQLinearMethod(LinearMethodBase):
|
||||
}
|
||||
|
||||
def apply_weights(self,
|
||||
weights: Dict[str, torch.Tensor],
|
||||
weights: Dict[str, Any],
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
qweight = weights["qweight"]
|
||||
|
||||
215
vllm/model_executor/layers/quantization/gptq.py
Normal file
215
vllm/model_executor/layers/quantization/gptq.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import enum
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm._C import ops
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
|
||||
|
||||
class GPTQConfig(QuantizationConfig):
|
||||
"""Config class for GPTQ.
|
||||
|
||||
Reference: https://arxiv.org/abs/2210.17323
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
desc_act: bool,
|
||||
) -> None:
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.desc_act = desc_act
|
||||
self.pack_factor = 32 // self.weight_bits
|
||||
# exllama kernel v1 only supports 4 bit
|
||||
if self.weight_bits != 4:
|
||||
raise ValueError(
|
||||
"Currently, only 4-bit weight quantization is supported for "
|
||||
f"GPTQ, but got {self.weight_bits} bits.")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"GPTQConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "gptq"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
def get_min_capability(cls) -> int:
|
||||
return 60
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
return cls(weight_bits, group_size, desc_act)
|
||||
|
||||
def get_linear_method(self) -> "GPTQLinearMethod":
|
||||
return GPTQLinearMethod(self)
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
class ExllamaState(Enum):
|
||||
|
||||
UNUSED = enum.auto()
|
||||
UNINITIALIZED = enum.auto()
|
||||
READY = enum.auto()
|
||||
|
||||
|
||||
class GPTQLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GPTQ.
|
||||
|
||||
Args:
|
||||
quant_config: The GPTQ quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: GPTQConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
input_size_per_partition: int,
|
||||
output_size_per_partition: int,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
) -> Dict[str, Any]:
|
||||
del output_size # Unused.
|
||||
if input_size_per_partition % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
if output_size_per_partition % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
exllama_state = ExllamaState.UNINITIALIZED
|
||||
scale_and_zero_size = input_size // group_size
|
||||
scale_and_zero_input_dim = None
|
||||
if input_size != input_size_per_partition and self.quant_config.group_size != -1:
|
||||
# For act-order models, we cannot use Exllama for row parallel layer
|
||||
if self.quant_config.desc_act:
|
||||
exllama_state = ExllamaState.UNUSED
|
||||
else:
|
||||
# we need to partition qzeros and scales for exllama kernel
|
||||
scale_and_zero_size = input_size_per_partition // group_size
|
||||
scale_and_zero_input_dim = 0
|
||||
|
||||
qweight = Parameter(
|
||||
torch.empty(
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
output_size_per_partition,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
qweight, {
|
||||
"input_dim": 0,
|
||||
"output_dim": 1,
|
||||
"packed_dim": 0,
|
||||
"pack_factor": self.quant_config.pack_factor,
|
||||
})
|
||||
g_idx = Parameter(
|
||||
torch.tensor(
|
||||
[
|
||||
i // self.quant_config.group_size
|
||||
for i in range(input_size_per_partition)
|
||||
],
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
# Ignore warning from fused linear layers such as QKVParallelLinear.
|
||||
set_weight_attrs(g_idx, {"input_dim": 0, "ignore_warning": True})
|
||||
qzeros = Parameter(
|
||||
torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
qzeros, {
|
||||
"input_dim": scale_and_zero_input_dim,
|
||||
"output_dim": 1,
|
||||
"packed_dim": 1,
|
||||
"pack_factor": self.quant_config.pack_factor,
|
||||
})
|
||||
scales = Parameter(
|
||||
torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition,
|
||||
device="cuda",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(scales, {
|
||||
"input_dim": scale_and_zero_input_dim,
|
||||
"output_dim": 1,
|
||||
})
|
||||
return {
|
||||
"qweight": qweight,
|
||||
"g_idx": g_idx,
|
||||
"qzeros": qzeros,
|
||||
"scales": scales,
|
||||
"exllama_state": exllama_state,
|
||||
}
|
||||
|
||||
def apply_weights(self,
|
||||
weights: Dict[str, Any],
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
qweight = weights["qweight"]
|
||||
out_shape = x.shape[:-1] + (qweight.shape[-1], )
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
# exllama needs to shuffle the weight after the weight is loaded
|
||||
# here we do the shuffle on first forward pass
|
||||
if weights["exllama_state"] == ExllamaState.UNINITIALIZED:
|
||||
if self.quant_config.desc_act:
|
||||
weights["g_idx"] = torch.argsort(weights["g_idx"]).to(
|
||||
torch.int)
|
||||
else:
|
||||
weights["g_idx"] = torch.empty((1, 1), device="meta")
|
||||
weights["exllama_state"] = ExllamaState.READY
|
||||
ops.gptq_shuffle(weights["qweight"], weights["g_idx"])
|
||||
output = ops.gptq_gemm(reshaped_x, weights["qweight"],
|
||||
weights["qzeros"], weights["scales"],
|
||||
weights["g_idx"],
|
||||
weights["exllama_state"] == ExllamaState.READY)
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output.reshape(out_shape)
|
||||
@@ -67,17 +67,19 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
|
||||
def __init__(self, quant_config: SqueezeLLMConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
|
||||
if input_size % self.quant_config.pack_factor != 0:
|
||||
def create_weights(self, input_size_per_partition: int,
|
||||
output_size_per_partition: int, input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
if input_size_per_partition % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
qweight = Parameter(
|
||||
torch.empty(
|
||||
input_size // self.quant_config.pack_factor,
|
||||
output_size,
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
output_size_per_partition,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
@@ -108,7 +110,7 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
|
||||
}
|
||||
|
||||
def apply_weights(self,
|
||||
weights: Dict[str, torch.Tensor],
|
||||
weights: Dict[str, Any],
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
qweight = weights["qweight"]
|
||||
|
||||
@@ -332,11 +332,18 @@ class AquilaForCausalLM(nn.Module):
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
@@ -355,11 +355,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
@@ -377,6 +377,9 @@ class ChatGLMForCausalLM(nn.Module):
|
||||
continue
|
||||
if "word_embeddings" in name:
|
||||
name = name.replace(".word_embeddings", "")
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
@@ -425,27 +425,32 @@ class FalconForCausalLM(nn.Module):
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
if "query_key_value" in name:
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
loaded_weight_shape = loaded_weight.shape
|
||||
loaded_weight = loaded_weight.view(
|
||||
loaded_weight_shape[:output_dim] +
|
||||
(total_num_kv_heads, num_query_heads_per_kv_head + 2, -1) +
|
||||
loaded_weight_shape[output_dim + 1:])
|
||||
wq = loaded_weight.narrow(
|
||||
output_dim + 1, 0, num_query_heads_per_kv_head).reshape(
|
||||
*loaded_weight_shape[:output_dim], -1,
|
||||
*loaded_weight_shape[output_dim + 1:])
|
||||
wk = loaded_weight.narrow(
|
||||
output_dim + 1, num_query_heads_per_kv_head,
|
||||
1).reshape(*loaded_weight_shape[:output_dim], -1,
|
||||
*loaded_weight_shape[output_dim + 1:])
|
||||
wv = loaded_weight.narrow(
|
||||
output_dim + 1, num_query_heads_per_kv_head + 1,
|
||||
1).reshape(*loaded_weight_shape[:output_dim], -1,
|
||||
*loaded_weight_shape[output_dim + 1:])
|
||||
loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
|
||||
if output_dim is not None:
|
||||
loaded_weight = loaded_weight.view(
|
||||
loaded_weight_shape[:output_dim] +
|
||||
(total_num_kv_heads, num_query_heads_per_kv_head + 2,
|
||||
-1) + loaded_weight_shape[output_dim + 1:])
|
||||
wq = loaded_weight.narrow(
|
||||
output_dim + 1, 0,
|
||||
num_query_heads_per_kv_head).reshape(
|
||||
*loaded_weight_shape[:output_dim], -1,
|
||||
*loaded_weight_shape[output_dim + 1:])
|
||||
wk = loaded_weight.narrow(
|
||||
output_dim + 1, num_query_heads_per_kv_head,
|
||||
1).reshape(*loaded_weight_shape[:output_dim], -1,
|
||||
*loaded_weight_shape[output_dim + 1:])
|
||||
wv = loaded_weight.narrow(
|
||||
output_dim + 1, num_query_heads_per_kv_head + 1,
|
||||
1).reshape(*loaded_weight_shape[:output_dim], -1,
|
||||
*loaded_weight_shape[output_dim + 1:])
|
||||
loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
|
||||
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
@@ -275,7 +275,6 @@ class GPT2LMHeadModel(nn.Module):
|
||||
if not name.endswith(".weight"):
|
||||
continue
|
||||
loaded_weight = loaded_weight.t()
|
||||
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
@@ -274,11 +274,18 @@ class GPTJForCausalLM(nn.Module):
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
@@ -72,7 +72,6 @@ class GPTNeoXAttention(nn.Module):
|
||||
config.hidden_size,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
scaling = self.head_size**-0.5
|
||||
rotary_dim = int(self.head_size * config.rotary_pct)
|
||||
assert rotary_dim % 2 == 0
|
||||
|
||||
@@ -289,11 +289,18 @@ class InternLMForCausalLM(nn.Module):
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
@@ -330,11 +330,18 @@ class LlamaForCausalLM(nn.Module):
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
@@ -321,11 +321,18 @@ class MistralForCausalLM(nn.Module):
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
@@ -153,7 +153,7 @@ class MixtralMoE(nn.Module):
|
||||
self.gate = ReplicatedLinear(config.hidden_size,
|
||||
self.num_total_experts,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
linear_method=None)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
@@ -418,11 +418,18 @@ class MixtralForCausalLM(nn.Module):
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
@@ -297,6 +297,9 @@ class MPTForCausalLM(nn.Module):
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
@@ -345,11 +345,18 @@ class OPTForCausalLM(nn.Module):
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
@@ -305,6 +305,9 @@ class PhiForCausalLM(nn.Module):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# pylint: disable=E1136
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
|
||||
@@ -82,7 +82,6 @@ class QWenAttention(nn.Module):
|
||||
self.num_heads = (self.total_num_heads //
|
||||
tensor_model_parallel_world_size)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
|
||||
self.c_attn = QKVParallelLinear(
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
@@ -279,11 +278,18 @@ class QWenLMHeadModel(nn.Module):
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
@@ -320,11 +320,18 @@ class YiForCausalLM(nn.Module):
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
@@ -287,4 +287,5 @@ def initialize_dummy_weights(
|
||||
values between -1e-3 and 1e-3 works well for most models.
|
||||
"""
|
||||
for param in model.state_dict().values():
|
||||
param.data.uniform_(low, high)
|
||||
if torch.is_floating_point(param):
|
||||
param.data.uniform_(low, high)
|
||||
|
||||
Reference in New Issue
Block a user