[model][refactor] remove hard-coded CUDA device in models and layers (#13658)

This commit is contained in:
Mengqing Cao
2025-02-24 22:10:14 +08:00
committed by GitHub
parent 437b76ff59
commit 23eca9cf68
7 changed files with 29 additions and 14 deletions

View File

@@ -914,7 +914,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if param_name + "." in k:
quant_state[k] = temp_state_dict[k]
- return QuantState.from_dict(quant_state, device="cuda")
+ return QuantState.from_dict(quant_state,
+ device=current_platform.device_type)
# Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state