[Misc] Modify BNB parameter name (#9997)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li
2024-11-06 03:40:08 +08:00
committed by GitHub
parent d2e80332a7
commit b9c64c0ca7
3 changed files with 11 additions and 14 deletions

View File

@@ -203,8 +203,9 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
qweight = create_qweight_for_8bit()
else:
qweight = create_qweight_for_4bit()
layer.register_parameter("qweight", qweight)
# Enable parameters to have the same name as in the BNB
# checkpoint format.
layer.register_parameter("weight", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
def apply(self,
@@ -234,7 +235,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
reshape_after_matmul = True
bf_x = x.to(torch.bfloat16)
qweight = layer.qweight
qweight = layer.weight
offsets = qweight.bnb_shard_offsets
quant_states = qweight.bnb_quant_state
matmul_states = qweight.matmul_state
@@ -313,7 +314,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
reshape_after_matmul = True
bf_x = x.to(torch.bfloat16)
qweight = layer.qweight
qweight = layer.weight
quant_states = qweight.bnb_quant_state
offsets = qweight.bnb_shard_offsets