support bitsandbytes quantization with more models (#9148)

This commit is contained in:
chenqianfzh
2024-10-08 18:52:19 -07:00
committed by GitHub
parent 9ba0bd6aa6
commit 2f4117c38e
10 changed files with 165 additions and 28 deletions

View File

@@ -315,6 +315,19 @@ class OPTModel(nn.Module):
class OPTForCausalLM(nn.Module, SupportsPP):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
}
default_bitsandbytes_target_modules = [
".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2."
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".out_proj.", ".fc2."]
def __init__(
self,
config: OPTConfig,