diff --git a/scripts/dequant_fp8_to_bf16.py b/scripts/dequant_fp8_to_bf16.py index 9f3f5a9..ca947fb 100644 --- a/scripts/dequant_fp8_to_bf16.py +++ b/scripts/dequant_fp8_to_bf16.py @@ -210,6 +210,11 @@ def dequantize_model(model_dir: str, out_dir: str): stats["scales_removed"] += 1 out_path = os.path.join(out_dir, os.path.basename(f)) + if os.path.exists(out_path) and os.path.getsize(out_path) > 0: + # Resume: skip already-dequantized shards + print(f"[{i+1}/{total_shards}] Skipping (already done): {os.path.basename(f)}") + del tensors, scales_in_shard + continue save_file(tensors, out_path) shard_time = time.time() - shard_start