diff --git a/quantize_modelopt.py b/quantize_modelopt.py index 1e01305..b500b90 100644 --- a/quantize_modelopt.py +++ b/quantize_modelopt.py @@ -110,7 +110,7 @@ def main(): model = AutoModelForCausalLM.from_pretrained(args.model, **model_kwargs) - if not args.use_seq_device_map: + if not args.use_seq_device_map and not args.low_memory_mode: model = model.cuda() # Build calibration dataloader