diff --git a/patches/patch_finegrained_fp8_blackwell.py b/patches/patch_finegrained_fp8_blackwell.py new file mode 100644 index 0000000..c79f5af --- /dev/null +++ b/patches/patch_finegrained_fp8_blackwell.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +""" +Patch transformers' finegrained_fp8.py to reject DeepGEMM on Blackwell (SM100+). + +DeepGEMM only supports Hopper (SM90). On Blackwell GPUs, _load_deepgemm_kernel() +passes the SM90 check but then fails trying to download/load the kernel from HF Hub +(rate limits, missing builds). This patch adds a check for SM100+ that raises +ImportError, which the existing try/except in w8a8_fp8_matmul catches, falling +back to the Triton finegrained-fp8 kernel. + +Also needed because the Triton finegrained-fp8 matmul has shape mismatches during +modelopt calibration (K mismatch on quantized expert weights). The real fix is to +upcast the model to BF16 first (see scripts/upcast_to_bf16.py). + +Usage: + python3 patch_finegrained_fp8_blackwell.py [path_to_finegrained_fp8.py] + +If no path given, auto-detects from the installed transformers package. +""" + +import sys +import os + + +def patch(fp8_file: str): + with open(fp8_file) as f: + content = f.read() + + old = """ # DeepGEMM requires Hopper (SM90) or newer for FP8 WGMMA instructions + major = torch.cuda.get_device_capability()[0] + if major < 9: + raise ImportError( + f"DeepGEMM requires a Hopper (SM90+) or newer GPU, but the current device " + f"has compute capability {major}.x. Use a different `experts_implementation`." + )""" + + new = """ # DeepGEMM requires Hopper (SM90) specifically - not yet supported on Blackwell (SM100+) + major = torch.cuda.get_device_capability()[0] + if major < 9: + raise ImportError( + f"DeepGEMM requires a Hopper (SM90+) or newer GPU, but the current device " + f"has compute capability {major}.x. Use a different `experts_implementation`." + ) + if major >= 10: + raise ImportError( + f"DeepGEMM is not yet supported on Blackwell (SM100+). " + f"Use a different `experts_implementation`." + )""" + + if old in content: + content = content.replace(old, new) + with open(fp8_file, "w") as f: + f.write(content) + print(f"PATCHED: {fp8_file} — DeepGEMM now rejected on Blackwell (SM100+)") + else: + print("Patch target not found (may already be patched or different version)") + + +if __name__ == "__main__": + if len(sys.argv) > 1: + fp8_file = sys.argv[1] + else: + import transformers.integrations.finegrained_fp8 as fp8 + import inspect + fp8_file = inspect.getfile(fp8) + patch(fp8_file) diff --git a/scripts/upcast_to_bf16.py b/scripts/upcast_to_bf16.py new file mode 100644 index 0000000..5d1fcd4 --- /dev/null +++ b/scripts/upcast_to_bf16.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +""" +Upcast a mixed-precision DeepSeek V4 Pro model to pure BF16. + +Converts all FP8 tensors (float8_e8m0fnu, float8_e4m3fn, float8_e5m2) +to bfloat16 so that modelopt's PTQ calibration can run without hitting +broken FP8 kernel paths (DeepGEMM doesn't support Blackwell, and the +Triton finegrained-fp8 matmul has shape mismatches during quantization). + +Usage: + python3 upcast_to_bf16.py /path/to/DeepSeek-V4-Pro /path/to/DeepSeek-V4-Pro-BF16 + +The output model will have the same shard structure, same config (with +torch_dtype updated to bfloat16), and zero FP8 tensors. +""" + +import os +import glob +import shutil +import argparse + +from safetensors import safe_open +from safetensors.torch import save_file +import torch + +FP8_DTYPES = (torch.float8_e8m0fnu, torch.float8_e4m3fn, torch.float8_e5m2) + + +def upcast_model(model_dir: str, out_dir: str): + os.makedirs(out_dir, exist_ok=True) + + # Copy non-safetensor files (config, tokenizer, etc.) + for f in os.listdir(model_dir): + fp = os.path.join(model_dir, f) + if not f.endswith(".safetensors") and os.path.isfile(fp): + shutil.copy2(fp, os.path.join(out_dir, f)) + print(f"Copied {f}") + + # Convert safetensors shard by shard + safetensor_files = sorted(glob.glob(os.path.join(model_dir, "*.safetensors"))) + total = len(safetensor_files) + fp8_count = 0 + + for i, f in enumerate(safetensor_files): + tensors = {} + with safe_open(f, framework="pt") as sf: + for key in sf.keys(): + t = sf.get_tensor(key) + if t.dtype in FP8_DTYPES: + t = t.to(torch.bfloat16) + fp8_count += 1 + tensors[key] = t + + out_path = os.path.join(out_dir, os.path.basename(f)) + save_file(tensors, out_path) + del tensors # free memory + if (i + 1) % 10 == 0 or i == total - 1: + print(f"[{i + 1}/{total}] {os.path.basename(f)} (converted {fp8_count} FP8 tensors)") + + print(f"\nDone! FP8->BF16 tensors: {fp8_count}") + + # Verify: count remaining FP8 tensors + remaining_fp8 = 0 + for f in sorted(glob.glob(os.path.join(out_dir, "*.safetensors"))): + with safe_open(f, framework="pt") as sf: + for key in sf.keys(): + if sf.get_tensor(key).dtype in FP8_DTYPES: + remaining_fp8 += 1 + print(f"Verification: {remaining_fp8} FP8 tensors remaining (should be 0)") + + out_size = sum( + os.path.getsize(os.path.join(out_dir, f)) + for f in os.listdir(out_dir) + if f.endswith(".safetensors") + ) + print(f"Output size: {out_size / 1e12:.2f} TB") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Upcast DeepSeek V4 Pro mixed-precision to BF16") + parser.add_argument("model_dir", help="Path to mixed-precision model") + parser.add_argument("out_dir", help="Path to write BF16 model") + args = parser.parse_args() + upcast_model(args.model_dir, args.out_dir)