[Misc] Modify BNB parameter name (#9997)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -203,8 +203,9 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
|
||||
qweight = create_qweight_for_8bit()
|
||||
else:
|
||||
qweight = create_qweight_for_4bit()
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
# Enable parameters to have the same name as in the BNB
|
||||
# checkpoint format.
|
||||
layer.register_parameter("weight", qweight)
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
|
||||
def apply(self,
|
||||
@@ -234,7 +235,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
|
||||
reshape_after_matmul = True
|
||||
bf_x = x.to(torch.bfloat16)
|
||||
|
||||
qweight = layer.qweight
|
||||
qweight = layer.weight
|
||||
offsets = qweight.bnb_shard_offsets
|
||||
quant_states = qweight.bnb_quant_state
|
||||
matmul_states = qweight.matmul_state
|
||||
@@ -313,7 +314,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
|
||||
reshape_after_matmul = True
|
||||
bf_x = x.to(torch.bfloat16)
|
||||
|
||||
qweight = layer.qweight
|
||||
qweight = layer.weight
|
||||
quant_states = qweight.bnb_quant_state
|
||||
offsets = qweight.bnb_shard_offsets
|
||||
|
||||
|
||||
Reference in New Issue
Block a user