#!/usr/bin/env python3
"""
Dequantize FP8 weights to BF16 for DeepSeek V4 Pro mixed-precision model.

The converted checkpoint is written to a separate output directory; tensor
names are kept, only dtypes change.

For each FP8 weight tensor (float8_e4m3fn) paired with a block-wise scale
(float8_e8m0fnu), reconstructs the BF16 weight as:
    bf16_weight = fp8_weight * scale_expanded

After dequantization, FP8Linear.forward() sees element_size() > 1 and
falls back to F.linear(), avoiding the broken FP8 kernel paths on Blackwell.

Preserves all BF16 and FP32 tensors unchanged.
Removes the scale tensors that were consumed by a dequantization; any
``.scale`` tensor that was NOT consumed is kept so no data is lost.
"""

import glob
import json
import os
import shutil

import torch
from safetensors import safe_open
from safetensors.torch import save_file

FP8_WEIGHT_DTYPE = torch.float8_e4m3fn
FP8_SCALE_DTYPE = torch.float8_e8m0fnu
BLOCK_SIZE = (128, 128)

# Standard file name of the shard index written next to sharded checkpoints.
_INDEX_FILENAME = "model.safetensors.index.json"


def dequantize_weight(fp8_weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Dequantize a block-wise quantized FP8 weight to BF16.

    Each scale entry covers one ``BLOCK_SIZE`` tile of the last two
    dimensions of the weight. Leading (batch) dimensions, if any, are
    broadcast in the usual way.

    Args:
        fp8_weight: ``(..., out_features, in_features)`` tensor, normally
            float8_e4m3fn.
        scale: ``(..., ceil(out_features/128), ceil(in_features/128))``
            tensor, float8_e8m0fnu or float32.

    Returns:
        A bfloat16 tensor with the same shape as ``fp8_weight``.

    Raises:
        ValueError: if the tensors are not of equal rank >= 2, or the scale
            has too few blocks to cover the weight.
    """
    if fp8_weight.dim() < 2 or scale.dim() != fp8_weight.dim():
        raise ValueError(
            "expected weight and scale of equal rank >= 2, got shapes "
            f"{tuple(fp8_weight.shape)} and {tuple(scale.shape)}"
        )

    block_rows, block_cols = BLOCK_SIZE
    out_features, in_features = fp8_weight.shape[-2:]
    if (
        scale.shape[-2] * block_rows < out_features
        or scale.shape[-1] * block_cols < in_features
    ):
        raise ValueError(
            f"scale of shape {tuple(scale.shape)} does not cover weight of shape "
            f"{tuple(fp8_weight.shape)} with block size {BLOCK_SIZE}"
        )

    # Blow every scale entry up to a full tile, then trim the overhang that
    # appears when the weight dimensions are not multiples of the block size.
    scale_expanded = (
        scale.float()
        .repeat_interleave(block_rows, dim=-2)
        .repeat_interleave(block_cols, dim=-1)
    )
    scale_expanded = scale_expanded[..., :out_features, :in_features]

    # Multiply in float32 and round once, at the end, to bfloat16.
    return (fp8_weight.float() * scale_expanded).to(torch.bfloat16)


def _update_config(out_dir: str) -> None:
    """Mark the copied config.json (if present) as a plain BF16 checkpoint."""
    cfg_path = os.path.join(out_dir, "config.json")
    if not os.path.exists(cfg_path):
        return
    with open(cfg_path, encoding="utf-8") as fh:
        cfg = json.load(fh)
    cfg["torch_dtype"] = "bfloat16"
    cfg["_experts_implementation"] = "eager"
    cfg.pop("quantization_config", None)
    with open(cfg_path, "w", encoding="utf-8") as fh:
        json.dump(cfg, fh, indent=2)
    print("Updated config.json: torch_dtype=bfloat16, _experts_implementation=eager")


def _update_index(out_dir: str, removed_scale_keys: set, total_bytes: int) -> None:
    """Drop removed scale keys from the copied shard index (if present)."""
    index_path = os.path.join(out_dir, _INDEX_FILENAME)
    if not os.path.exists(index_path):
        return
    with open(index_path, encoding="utf-8") as fh:
        index = json.load(fh)
    weight_map = index.get("weight_map")
    if isinstance(weight_map, dict):
        for key in removed_scale_keys:
            weight_map.pop(key, None)
    metadata = index.get("metadata")
    if isinstance(metadata, dict) and "total_size" in metadata:
        # BF16 weights are twice the size of FP8 ones; keep the figure honest.
        metadata["total_size"] = total_bytes
    with open(index_path, "w", encoding="utf-8") as fh:
        json.dump(index, fh, indent=2)
    print(f"Updated {_INDEX_FILENAME}: removed {len(removed_scale_keys)} scale entries")


def dequantize_model(model_dir: str, out_dir: str) -> None:
    """Write a BF16 copy of the mixed-precision checkpoint in ``model_dir``.

    Steps:
      1. Copy every regular non-``.safetensors`` file to ``out_dir``.
      2. Scan all shards for ``*.scale`` keys and pair each with the
         ``*.weight`` key of the same prefix.
      3. Re-save every shard under the same file name: FP8 weights whose
         scale lives in the same shard are dequantized and that scale is
         dropped; everything else is written back unchanged.
      4. Patch the copied ``config.json`` and shard index, when present.

    Args:
        model_dir: directory holding the source ``*.safetensors`` shards.
        out_dir: directory to write to; created if missing.
    """
    os.makedirs(out_dir, exist_ok=True)

    # Copy non-safetensor files
    for name in os.listdir(model_dir):
        src = os.path.join(model_dir, name)
        if not name.endswith(".safetensors") and os.path.isfile(src):
            shutil.copy2(src, os.path.join(out_dir, name))
            print(f"Copied {name}")

    safetensor_files = sorted(glob.glob(os.path.join(model_dir, "*.safetensors")))
    total = len(safetensor_files)

    # First pass: build map of weight_key -> scale_key
    # Pattern: layers.X.attn.Y.scale -> layers.X.attn.Y.weight
    scales_map = {}

    print("Scanning for FP8 weight+scale pairs...")
    for shard_path in safetensor_files:
        with safe_open(shard_path, framework="pt") as sf:
            for key in sf.keys():
                if key.endswith(".scale"):
                    weight_key = key[: -len(".scale")] + ".weight"
                    scales_map[weight_key] = key
    print(f"Found {len(scales_map)} FP8 weight+scale pairs")

    # Second pass: dequantize and save
    fp8_dequantized = 0
    fp8_scales_removed = 0
    scales_retained = 0
    total_bytes = 0
    removed_scale_keys = set()
    scale_keys_global = set(scales_map.values())

    for i, shard_path in enumerate(safetensor_files):
        tensors = {}
        scales_in_shard = {}
        fp8_weights_in_shard = {}

        with safe_open(shard_path, framework="pt") as sf:
            # Carry the header metadata (e.g. the "format" tag) over to the
            # re-saved shard instead of silently discarding it.
            shard_metadata = sf.metadata()
            for key in sf.keys():
                t = sf.get_tensor(key)
                if key in scale_keys_global:
                    # Scale tensor: hold it back until we know whether a
                    # dequantization in this shard consumes it.
                    scales_in_shard[key] = t
                elif key in scales_map and t.dtype == FP8_WEIGHT_DTYPE:
                    # FP8 weight that has a corresponding scale
                    fp8_weights_in_shard[key] = t
                else:
                    # Regular tensor (BF16, FP32, or FP8 without scale) - keep as is
                    tensors[key] = t

        # Dequantize FP8 weights
        for weight_key, fp8_w in fp8_weights_in_shard.items():
            scale_key = scales_map[weight_key]
            scale = scales_in_shard.pop(scale_key, None)
            if scale is None:
                # Scale not in this shard (shouldn't happen but handle gracefully)
                print(f"WARNING: scale {scale_key} not found in same shard as {weight_key}")
                tensors[weight_key] = fp8_w  # keep as-is
                continue
            tensors[weight_key] = dequantize_weight(fp8_w, scale)
            fp8_dequantized += 1
            removed_scale_keys.add(scale_key)
            fp8_scales_removed += 1

        # Whatever is left was not used: its weight is either not FP8 or sits
        # in another shard and stays quantized. Dropping it would lose data.
        scales_retained += len(scales_in_shard)
        tensors.update(scales_in_shard)

        total_bytes += sum(t.numel() * t.element_size() for t in tensors.values())

        out_path = os.path.join(out_dir, os.path.basename(shard_path))
        save_file(tensors, out_path, metadata=shard_metadata)
        # Release the shard before loading the next one to cap peak memory.
        del tensors, scales_in_shard, fp8_weights_in_shard

        if (i + 1) % 10 == 0 or i == total - 1:
            print(f"[{i+1}/{total}] {os.path.basename(shard_path)} (dequantized {fp8_dequantized} FP8, removed {fp8_scales_removed} scales)")

    # Update config and shard index
    _update_config(out_dir)
    _update_index(out_dir, removed_scale_keys, total_bytes)

    if scales_retained:
        print(f"Kept {scales_retained} scale tensors that had no dequantized weight in their shard")
    print(f"\nDone! Dequantized {fp8_dequantized} FP8 weights, removed {fp8_scales_removed} scale tensors")


def main() -> None:
    """Parse the command line and run the conversion."""
    import argparse

    parser = argparse.ArgumentParser(description="Dequantize FP8 weights to BF16")
    parser.add_argument("model_dir", help="Path to mixed-precision model")
    parser.add_argument("out_dir", help="Path to write dequantized BF16 model")
    args = parser.parse_args()
    dequantize_model(args.model_dir, args.out_dir)


if __name__ == "__main__":
    main()