Add BF16 upcast script and Blackwell DeepGEMM patch
- scripts/upcast_to_bf16.py: Converts mixed-precision V4 Pro to pure BF16 by upcasting all FP8 tensors (float8_e8m0fnu etc.) to bfloat16. Needed because modelopt PTQ calibration crashes on Blackwell with FP8 kernels (DeepGEMM unsupported, Triton finegrained-fp8 has K mismatches). - patches/patch_finegrained_fp8_blackwell.py: Patches transformers to reject DeepGEMM on SM100+ (Blackwell), letting it fall back to Triton. Note: the Triton fallback also fails during modelopt calibration on quantized weights, so upcasting to BF16 is the working solution.
This commit is contained in:
66
patches/patch_finegrained_fp8_blackwell.py
Normal file
66
patches/patch_finegrained_fp8_blackwell.py
Normal file
@@ -0,0 +1,66 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Patch transformers' finegrained_fp8.py to reject DeepGEMM on Blackwell (SM100+).
|
||||
|
||||
DeepGEMM only supports Hopper (SM90). On Blackwell GPUs, _load_deepgemm_kernel()
|
||||
passes the SM90 check but then fails trying to download/load the kernel from HF Hub
|
||||
(rate limits, missing builds). This patch adds a check for SM100+ that raises
|
||||
ImportError, which the existing try/except in w8a8_fp8_matmul catches, falling
|
||||
back to the Triton finegrained-fp8 kernel.
|
||||
|
||||
Also needed because the Triton finegrained-fp8 matmul has shape mismatches during
|
||||
modelopt calibration (K mismatch on quantized expert weights). The real fix is to
|
||||
upcast the model to BF16 first (see scripts/upcast_to_bf16.py).
|
||||
|
||||
Usage:
|
||||
python3 patch_finegrained_fp8_blackwell.py [path_to_finegrained_fp8.py]
|
||||
|
||||
If no path given, auto-detects from the installed transformers package.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
|
||||
def patch(fp8_file: str):
|
||||
with open(fp8_file) as f:
|
||||
content = f.read()
|
||||
|
||||
old = """ # DeepGEMM requires Hopper (SM90) or newer for FP8 WGMMA instructions
|
||||
major = torch.cuda.get_device_capability()[0]
|
||||
if major < 9:
|
||||
raise ImportError(
|
||||
f"DeepGEMM requires a Hopper (SM90+) or newer GPU, but the current device "
|
||||
f"has compute capability {major}.x. Use a different `experts_implementation`."
|
||||
)"""
|
||||
|
||||
new = """ # DeepGEMM requires Hopper (SM90) specifically - not yet supported on Blackwell (SM100+)
|
||||
major = torch.cuda.get_device_capability()[0]
|
||||
if major < 9:
|
||||
raise ImportError(
|
||||
f"DeepGEMM requires a Hopper (SM90+) or newer GPU, but the current device "
|
||||
f"has compute capability {major}.x. Use a different `experts_implementation`."
|
||||
)
|
||||
if major >= 10:
|
||||
raise ImportError(
|
||||
f"DeepGEMM is not yet supported on Blackwell (SM100+). "
|
||||
f"Use a different `experts_implementation`."
|
||||
)"""
|
||||
|
||||
if old in content:
|
||||
content = content.replace(old, new)
|
||||
with open(fp8_file, "w") as f:
|
||||
f.write(content)
|
||||
print(f"PATCHED: {fp8_file} — DeepGEMM now rejected on Blackwell (SM100+)")
|
||||
else:
|
||||
print("Patch target not found (may already be patched or different version)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) > 1:
|
||||
fp8_file = sys.argv[1]
|
||||
else:
|
||||
import transformers.integrations.finegrained_fp8 as fp8
|
||||
import inspect
|
||||
fp8_file = inspect.getfile(fp8)
|
||||
patch(fp8_file)
|
||||
84
scripts/upcast_to_bf16.py
Normal file
84
scripts/upcast_to_bf16.py
Normal file
@@ -0,0 +1,84 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Upcast a mixed-precision DeepSeek V4 Pro model to pure BF16.
|
||||
|
||||
Converts all FP8 tensors (float8_e8m0fnu, float8_e4m3fn, float8_e5m2)
|
||||
to bfloat16 so that modelopt's PTQ calibration can run without hitting
|
||||
broken FP8 kernel paths (DeepGEMM doesn't support Blackwell, and the
|
||||
Triton finegrained-fp8 matmul has shape mismatches during quantization).
|
||||
|
||||
Usage:
|
||||
python3 upcast_to_bf16.py /path/to/DeepSeek-V4-Pro /path/to/DeepSeek-V4-Pro-BF16
|
||||
|
||||
The output model will have the same shard structure, same config (with
|
||||
torch_dtype updated to bfloat16), and zero FP8 tensors.
|
||||
"""
|
||||
|
||||
import os
|
||||
import glob
|
||||
import shutil
|
||||
import argparse
|
||||
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
import torch
|
||||
|
||||
FP8_DTYPES = (torch.float8_e8m0fnu, torch.float8_e4m3fn, torch.float8_e5m2)
|
||||
|
||||
|
||||
def upcast_model(model_dir: str, out_dir: str):
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
# Copy non-safetensor files (config, tokenizer, etc.)
|
||||
for f in os.listdir(model_dir):
|
||||
fp = os.path.join(model_dir, f)
|
||||
if not f.endswith(".safetensors") and os.path.isfile(fp):
|
||||
shutil.copy2(fp, os.path.join(out_dir, f))
|
||||
print(f"Copied {f}")
|
||||
|
||||
# Convert safetensors shard by shard
|
||||
safetensor_files = sorted(glob.glob(os.path.join(model_dir, "*.safetensors")))
|
||||
total = len(safetensor_files)
|
||||
fp8_count = 0
|
||||
|
||||
for i, f in enumerate(safetensor_files):
|
||||
tensors = {}
|
||||
with safe_open(f, framework="pt") as sf:
|
||||
for key in sf.keys():
|
||||
t = sf.get_tensor(key)
|
||||
if t.dtype in FP8_DTYPES:
|
||||
t = t.to(torch.bfloat16)
|
||||
fp8_count += 1
|
||||
tensors[key] = t
|
||||
|
||||
out_path = os.path.join(out_dir, os.path.basename(f))
|
||||
save_file(tensors, out_path)
|
||||
del tensors # free memory
|
||||
if (i + 1) % 10 == 0 or i == total - 1:
|
||||
print(f"[{i + 1}/{total}] {os.path.basename(f)} (converted {fp8_count} FP8 tensors)")
|
||||
|
||||
print(f"\nDone! FP8->BF16 tensors: {fp8_count}")
|
||||
|
||||
# Verify: count remaining FP8 tensors
|
||||
remaining_fp8 = 0
|
||||
for f in sorted(glob.glob(os.path.join(out_dir, "*.safetensors"))):
|
||||
with safe_open(f, framework="pt") as sf:
|
||||
for key in sf.keys():
|
||||
if sf.get_tensor(key).dtype in FP8_DTYPES:
|
||||
remaining_fp8 += 1
|
||||
print(f"Verification: {remaining_fp8} FP8 tensors remaining (should be 0)")
|
||||
|
||||
out_size = sum(
|
||||
os.path.getsize(os.path.join(out_dir, f))
|
||||
for f in os.listdir(out_dir)
|
||||
if f.endswith(".safetensors")
|
||||
)
|
||||
print(f"Output size: {out_size / 1e12:.2f} TB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Upcast DeepSeek V4 Pro mixed-precision to BF16")
|
||||
parser.add_argument("model_dir", help="Path to mixed-precision model")
|
||||
parser.add_argument("out_dir", help="Path to write BF16 model")
|
||||
args = parser.parse_args()
|
||||
upcast_model(args.model_dir, args.out_dir)
|
||||
Reference in New Issue
Block a user