[model][refactor] remove cuda hard code in models and layers (#13658)
This commit is contained in:
@@ -914,7 +914,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
if param_name + "." in k:
|
||||
quant_state[k] = temp_state_dict[k]
|
||||
|
||||
return QuantState.from_dict(quant_state, device="cuda")
|
||||
return QuantState.from_dict(quant_state,
|
||||
device=current_platform.device_type)
|
||||
|
||||
# Second iterate over all prequant and normal weights
|
||||
# pre quantized weights would have a quant_state
|
||||
|
||||
Reference in New Issue
Block a user