Add GPTQ support (#916)

This commit is contained in:
CHU Tianxiang
2023-12-15 19:04:22 +08:00
committed by GitHub
parent c06170cc8e
commit 0fbfc4b81b
35 changed files with 1782 additions and 82 deletions

View File

@@ -82,7 +82,6 @@ class QWenAttention(nn.Module):
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
self.c_attn = QKVParallelLinear(
hidden_size,
self.head_dim,
@@ -279,11 +278,18 @@ class QWenLMHeadModel(nn.Module):
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)