Add BF16 upcast script and Blackwell DeepGEMM patch

- scripts/upcast_to_bf16.py: Converts mixed-precision V4 Pro to pure BF16
  by upcasting all FP8 tensors (float8_e8m0fnu etc.) to bfloat16.
  Needed because modelopt PTQ calibration crashes on Blackwell with FP8
  kernels (DeepGEMM unsupported, Triton finegrained-fp8 has K mismatches).

- patches/patch_finegrained_fp8_blackwell.py: Patches transformers to
  reject DeepGEMM on SM100+ (Blackwell), letting it fall back to Triton.
  Note: the Triton fallback also fails during modelopt calibration on
  quantized weights, so upcasting to BF16 is the working solution.
This commit is contained in:
2026-05-07 14:25:20 +00:00
parent ef89ceffbd
commit 7a3b81e833
2 changed files with 150 additions and 0 deletions

View File

@@ -0,0 +1,66 @@
#!/usr/bin/env python3
"""
Patch transformers' finegrained_fp8.py to reject DeepGEMM on Blackwell (SM100+).
DeepGEMM only supports Hopper (SM90). On Blackwell GPUs, _load_deepgemm_kernel()
passes the SM90 check but then fails trying to download/load the kernel from HF Hub
(rate limits, missing builds). This patch adds a check for SM100+ that raises
ImportError, which the existing try/except in w8a8_fp8_matmul catches, falling
back to the Triton finegrained-fp8 kernel.
Also needed because the Triton finegrained-fp8 matmul has shape mismatches during
modelopt calibration (K mismatch on quantized expert weights). The real fix is to
upcast the model to BF16 first (see scripts/upcast_to_bf16.py).
Usage:
python3 patch_finegrained_fp8_blackwell.py [path_to_finegrained_fp8.py]
If no path given, auto-detects from the installed transformers package.
"""
import sys
import os
def patch(fp8_file: str):
with open(fp8_file) as f:
content = f.read()
old = """ # DeepGEMM requires Hopper (SM90) or newer for FP8 WGMMA instructions
major = torch.cuda.get_device_capability()[0]
if major < 9:
raise ImportError(
f"DeepGEMM requires a Hopper (SM90+) or newer GPU, but the current device "
f"has compute capability {major}.x. Use a different `experts_implementation`."
)"""
new = """ # DeepGEMM requires Hopper (SM90) specifically - not yet supported on Blackwell (SM100+)
major = torch.cuda.get_device_capability()[0]
if major < 9:
raise ImportError(
f"DeepGEMM requires a Hopper (SM90+) or newer GPU, but the current device "
f"has compute capability {major}.x. Use a different `experts_implementation`."
)
if major >= 10:
raise ImportError(
f"DeepGEMM is not yet supported on Blackwell (SM100+). "
f"Use a different `experts_implementation`."
)"""
if old in content:
content = content.replace(old, new)
with open(fp8_file, "w") as f:
f.write(content)
print(f"PATCHED: {fp8_file} — DeepGEMM now rejected on Blackwell (SM100+)")
else:
print("Patch target not found (may already be patched or different version)")
if __name__ == "__main__":
if len(sys.argv) > 1:
fp8_file = sys.argv[1]
else:
import transformers.integrations.finegrained_fp8 as fp8
import inspect
fp8_file = inspect.getfile(fp8)
patch(fp8_file)

84
scripts/upcast_to_bf16.py Normal file
View File

@@ -0,0 +1,84 @@
#!/usr/bin/env python3
"""
Upcast a mixed-precision DeepSeek V4 Pro model to pure BF16.
Converts all FP8 tensors (float8_e8m0fnu, float8_e4m3fn, float8_e5m2)
to bfloat16 so that modelopt's PTQ calibration can run without hitting
broken FP8 kernel paths (DeepGEMM doesn't support Blackwell, and the
Triton finegrained-fp8 matmul has shape mismatches during quantization).
Usage:
python3 upcast_to_bf16.py /path/to/DeepSeek-V4-Pro /path/to/DeepSeek-V4-Pro-BF16
The output model will have the same shard structure, same config (with
torch_dtype updated to bfloat16), and zero FP8 tensors.
"""
import os
import glob
import shutil
import argparse
from safetensors import safe_open
from safetensors.torch import save_file
import torch
FP8_DTYPES = (torch.float8_e8m0fnu, torch.float8_e4m3fn, torch.float8_e5m2)
def upcast_model(model_dir: str, out_dir: str):
os.makedirs(out_dir, exist_ok=True)
# Copy non-safetensor files (config, tokenizer, etc.)
for f in os.listdir(model_dir):
fp = os.path.join(model_dir, f)
if not f.endswith(".safetensors") and os.path.isfile(fp):
shutil.copy2(fp, os.path.join(out_dir, f))
print(f"Copied {f}")
# Convert safetensors shard by shard
safetensor_files = sorted(glob.glob(os.path.join(model_dir, "*.safetensors")))
total = len(safetensor_files)
fp8_count = 0
for i, f in enumerate(safetensor_files):
tensors = {}
with safe_open(f, framework="pt") as sf:
for key in sf.keys():
t = sf.get_tensor(key)
if t.dtype in FP8_DTYPES:
t = t.to(torch.bfloat16)
fp8_count += 1
tensors[key] = t
out_path = os.path.join(out_dir, os.path.basename(f))
save_file(tensors, out_path)
del tensors # free memory
if (i + 1) % 10 == 0 or i == total - 1:
print(f"[{i + 1}/{total}] {os.path.basename(f)} (converted {fp8_count} FP8 tensors)")
print(f"\nDone! FP8->BF16 tensors: {fp8_count}")
# Verify: count remaining FP8 tensors
remaining_fp8 = 0
for f in sorted(glob.glob(os.path.join(out_dir, "*.safetensors"))):
with safe_open(f, framework="pt") as sf:
for key in sf.keys():
if sf.get_tensor(key).dtype in FP8_DTYPES:
remaining_fp8 += 1
print(f"Verification: {remaining_fp8} FP8 tensors remaining (should be 0)")
out_size = sum(
os.path.getsize(os.path.join(out_dir, f))
for f in os.listdir(out_dir)
if f.endswith(".safetensors")
)
print(f"Output size: {out_size / 1e12:.2f} TB")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Upcast DeepSeek V4 Pro mixed-precision to BF16")
parser.add_argument("model_dir", help="Path to mixed-precision model")
parser.add_argument("out_dir", help="Path to write BF16 model")
args = parser.parse_args()
upcast_model(args.model_dir, args.out_dir)