Remove upcast_to_bf16.py — superseded by dequant_fp8_to_bf16.py
This commit is contained in:
@@ -1,84 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Upcast a mixed-precision DeepSeek V4 Pro model to pure BF16.
|
||||
|
||||
Converts all FP8 tensors (float8_e8m0fnu, float8_e4m3fn, float8_e5m2)
|
||||
to bfloat16 so that modelopt's PTQ calibration can run without hitting
|
||||
broken FP8 kernel paths (DeepGEMM doesn't support Blackwell, and the
|
||||
Triton finegrained-fp8 matmul has shape mismatches during quantization).
|
||||
|
||||
Usage:
|
||||
python3 upcast_to_bf16.py /path/to/DeepSeek-V4-Pro /path/to/DeepSeek-V4-Pro-BF16
|
||||
|
||||
The output model will have the same shard structure, same config (with
|
||||
torch_dtype updated to bfloat16), and zero FP8 tensors.
|
||||
"""
|
||||
|
||||
import os
|
||||
import glob
|
||||
import shutil
|
||||
import argparse
|
||||
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
import torch
|
||||
|
||||
# Torch dtypes treated as FP8 by this script: upcast_model casts tensors of
# these dtypes to bfloat16 and its verification pass counts any that remain.
FP8_DTYPES = (torch.float8_e8m0fnu, torch.float8_e4m3fn, torch.float8_e5m2)
|
||||
|
||||
|
||||
def upcast_model(model_dir: str, out_dir: str):
    """Write a copy of the checkpoint in ``model_dir`` with FP8 tensors cast to BF16.

    Every regular file in ``model_dir`` that is not a ``.safetensors`` shard is
    copied to ``out_dir`` unchanged, except that a copied ``config.json`` has
    its ``torch_dtype`` entry set to ``"bfloat16"``. Each ``*.safetensors``
    shard is then rewritten under the same file name: tensors whose dtype is in
    ``FP8_DTYPES`` are cast with ``Tensor.to(torch.bfloat16)``, all other
    tensors are written back as loaded, and the shard's header metadata is
    carried over. Finally the output shards are re-read to count any remaining
    FP8 tensors and the total shard size is printed.

    NOTE(review): the conversion is a plain dtype cast of the stored values.
    If the checkpoint keeps separate scale tensors for block-scaled FP8
    weights, those scales are not applied here -- confirm this is intended.

    Args:
        model_dir: Directory holding the mixed-precision checkpoint.
        out_dir: Directory to write the BF16 checkpoint to; created if missing.

    Raises:
        ValueError: If ``out_dir`` resolves to the same directory as
            ``model_dir`` (the source shards would be overwritten).
    """
    import json  # local import: only needed for the config.json rewrite below

    # Refuse in-place conversion: shards are written under their original
    # names, so a shared directory would clobber the source checkpoint.
    if os.path.realpath(model_dir) == os.path.realpath(out_dir):
        raise ValueError(
            f"out_dir must differ from model_dir (both resolve to {os.path.realpath(model_dir)})"
        )

    os.makedirs(out_dir, exist_ok=True)

    # Copy non-safetensor files (config, tokenizer, etc.)
    for f in os.listdir(model_dir):
        fp = os.path.join(model_dir, f)
        if not f.endswith(".safetensors") and os.path.isfile(fp):
            shutil.copy2(fp, os.path.join(out_dir, f))
            print(f"Copied {f}")

    # The output contains no FP8 tensors, so record bfloat16 as the dtype in
    # the copied config instead of leaving the source value in place.
    config_path = os.path.join(out_dir, "config.json")
    if os.path.isfile(config_path):
        with open(config_path, encoding="utf-8") as fh:
            config = json.load(fh)
        if isinstance(config, dict) and config.get("torch_dtype") != "bfloat16":
            config["torch_dtype"] = "bfloat16"
            with open(config_path, "w", encoding="utf-8") as fh:
                json.dump(config, fh, indent=2, ensure_ascii=False)
                fh.write("\n")
            print("Updated torch_dtype to bfloat16 in config.json")

    # Convert safetensors shard by shard so only one shard is held in memory.
    safetensor_files = sorted(glob.glob(os.path.join(model_dir, "*.safetensors")))
    total = len(safetensor_files)
    fp8_count = 0

    for i, f in enumerate(safetensor_files):
        tensors = {}
        with safe_open(f, framework="pt") as sf:
            # Keep the header metadata (may be None) so the rewritten shard
            # carries the same metadata as the source shard.
            metadata = sf.metadata()
            for key in sf.keys():
                t = sf.get_tensor(key)
                if t.dtype in FP8_DTYPES:
                    t = t.to(torch.bfloat16)
                    fp8_count += 1
                tensors[key] = t

        out_path = os.path.join(out_dir, os.path.basename(f))
        save_file(tensors, out_path, metadata=metadata)
        del tensors  # free memory before loading the next shard

        # Report progress every 10 shards and on the final shard.
        if (i + 1) % 10 == 0 or i == total - 1:
            print(f"[{i + 1}/{total}] {os.path.basename(f)} (converted {fp8_count} FP8 tensors)")

    print(f"\nDone! FP8->BF16 tensors: {fp8_count}")

    # Verify: count remaining FP8 tensors in the written shards.
    remaining_fp8 = 0
    for f in sorted(glob.glob(os.path.join(out_dir, "*.safetensors"))):
        with safe_open(f, framework="pt") as sf:
            for key in sf.keys():
                if sf.get_tensor(key).dtype in FP8_DTYPES:
                    remaining_fp8 += 1
    print(f"Verification: {remaining_fp8} FP8 tensors remaining (should be 0)")

    # Total on-disk size of the output shards, reported in decimal terabytes.
    out_size = sum(
        os.path.getsize(os.path.join(out_dir, f))
        for f in os.listdir(out_dir)
        if f.endswith(".safetensors")
    )
    print(f"Output size: {out_size / 1e12:.2f} TB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: read the source and destination directories from
    # the command line and run the conversion.
    cli = argparse.ArgumentParser(
        description="Upcast DeepSeek V4 Pro mixed-precision to BF16"
    )
    for arg_name, arg_help in (
        ("model_dir", "Path to mixed-precision model"),
        ("out_dir", "Path to write BF16 model"),
    ):
        cli.add_argument(arg_name, help=arg_help)
    opts = cli.parse_args()
    upcast_model(opts.model_dir, opts.out_dir)
|
||||
Reference in New Issue
Block a user