Remove upcast_to_bf16.py — superseded by dequant_fp8_to_bf16.py

This commit is contained in:
2026-05-08 17:13:39 +00:00
parent ca9a4f5eaa
commit f1d21900ea

View File

@@ -1,84 +0,0 @@
#!/usr/bin/env python3
"""
Upcast a mixed-precision DeepSeek V4 Pro model to pure BF16.
Converts all FP8 tensors (float8_e8m0fnu, float8_e4m3fn, float8_e5m2)
to bfloat16 so that modelopt's PTQ calibration can run without hitting
broken FP8 kernel paths (DeepGEMM doesn't support Blackwell, and the
Triton finegrained-fp8 matmul has shape mismatches during quantization).
Usage:
python3 upcast_to_bf16.py /path/to/DeepSeek-V4-Pro /path/to/DeepSeek-V4-Pro-BF16
The output model will have the same shard structure, same config (with
torch_dtype updated to bfloat16), and zero FP8 tensors.
"""
import os
import glob
import shutil
import argparse
from safetensors import safe_open
from safetensors.torch import save_file
import torch
FP8_DTYPES = (torch.float8_e8m0fnu, torch.float8_e4m3fn, torch.float8_e5m2)
def upcast_model(model_dir: str, out_dir: str):
os.makedirs(out_dir, exist_ok=True)
# Copy non-safetensor files (config, tokenizer, etc.)
for f in os.listdir(model_dir):
fp = os.path.join(model_dir, f)
if not f.endswith(".safetensors") and os.path.isfile(fp):
shutil.copy2(fp, os.path.join(out_dir, f))
print(f"Copied {f}")
# Convert safetensors shard by shard
safetensor_files = sorted(glob.glob(os.path.join(model_dir, "*.safetensors")))
total = len(safetensor_files)
fp8_count = 0
for i, f in enumerate(safetensor_files):
tensors = {}
with safe_open(f, framework="pt") as sf:
for key in sf.keys():
t = sf.get_tensor(key)
if t.dtype in FP8_DTYPES:
t = t.to(torch.bfloat16)
fp8_count += 1
tensors[key] = t
out_path = os.path.join(out_dir, os.path.basename(f))
save_file(tensors, out_path)
del tensors # free memory
if (i + 1) % 10 == 0 or i == total - 1:
print(f"[{i + 1}/{total}] {os.path.basename(f)} (converted {fp8_count} FP8 tensors)")
print(f"\nDone! FP8->BF16 tensors: {fp8_count}")
# Verify: count remaining FP8 tensors
remaining_fp8 = 0
for f in sorted(glob.glob(os.path.join(out_dir, "*.safetensors"))):
with safe_open(f, framework="pt") as sf:
for key in sf.keys():
if sf.get_tensor(key).dtype in FP8_DTYPES:
remaining_fp8 += 1
print(f"Verification: {remaining_fp8} FP8 tensors remaining (should be 0)")
out_size = sum(
os.path.getsize(os.path.join(out_dir, f))
for f in os.listdir(out_dir)
if f.endswith(".safetensors")
)
print(f"Output size: {out_size / 1e12:.2f} TB")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Upcast DeepSeek V4 Pro mixed-precision to BF16")
parser.add_argument("model_dir", help="Path to mixed-precision model")
parser.add_argument("out_dir", help="Path to write BF16 model")
args = parser.parse_args()
upcast_model(args.model_dir, args.out_dir)