#!/usr/bin/env python3
"""
Dequantize FP8 weights to BF16 for DeepSeek V4 Pro mixed-precision model.

For each FP8 weight tensor (float8_e4m3fn) paired with a block-wise scale
(float8_e8m0fnu), reconstructs the BF16 weight as:
    bf16_weight = fp8_weight * scale_expanded

After dequantization, FP8Linear.forward() sees element_size() > 1 and
falls back to F.linear(), avoiding the broken FP8 kernel paths on Blackwell.

Preserves all BF16 and FP32 tensors unchanged.
Removes the scale tensors that were consumed by dequantization; any scale
that was NOT consumed (its weight is not FP8, or is not in the same shard)
is kept so the output never loses data.

The converted checkpoint is written to a separate output directory; the
source directory is only read.
"""

import argparse
import glob
import json
import os
import shutil

import torch
from safetensors import safe_open
from safetensors.torch import save_file

FP8_WEIGHT_DTYPE = torch.float8_e4m3fn
FP8_SCALE_DTYPE = torch.float8_e8m0fnu
BLOCK_SIZE = (128, 128)

_SCALE_SUFFIX = ".scale"
_WEIGHT_SUFFIX = ".weight"


def dequantize_weight(
    fp8_weight: torch.Tensor,
    scale: torch.Tensor,
    block_size: tuple[int, int] = BLOCK_SIZE,
) -> torch.Tensor:
    """Dequantize a block-wise quantized FP8 weight to BF16.

    Args:
        fp8_weight: tensor of shape (..., out_features, in_features),
            float8_e4m3fn. Leading dimensions (e.g. stacked weights) are
            allowed as long as ``scale`` has the same leading dimensions.
        scale: tensor of shape (..., out_blocks, in_blocks), float8_e8m0fnu
            or float32, holding one scale per block of ``fp8_weight``.
        block_size: (rows, cols) covered by a single scale entry.

    Returns:
        A bfloat16 tensor with the same shape as ``fp8_weight``.

    Raises:
        ValueError: if the tensors are not at least 2-D, their leading
            dimensions differ, or the scale grid is too small to cover
            the weight.
    """
    if fp8_weight.dim() < 2 or scale.dim() != fp8_weight.dim():
        raise ValueError(
            f"weight {tuple(fp8_weight.shape)} and scale {tuple(scale.shape)} "
            "must both be at least 2-D and have the same number of dimensions"
        )
    if scale.shape[:-2] != fp8_weight.shape[:-2]:
        raise ValueError(
            f"leading dimensions differ: weight {tuple(fp8_weight.shape)} "
            f"vs scale {tuple(scale.shape)}"
        )

    block_rows, block_cols = block_size
    out_features, in_features = fp8_weight.shape[-2:]

    # Blow each scale entry up to cover its whole block.
    scale_expanded = (
        scale.float()
        .repeat_interleave(block_rows, dim=-2)
        .repeat_interleave(block_cols, dim=-1)
    )
    if (
        scale_expanded.shape[-2] < out_features
        or scale_expanded.shape[-1] < in_features
    ):
        raise ValueError(
            f"scale grid {tuple(scale.shape)} with block size {block_size} "
            f"does not cover weight {tuple(fp8_weight.shape)}"
        )

    # The last block in each dimension may be partial, so crop the expanded
    # grid to the weight's actual shape.
    scale_expanded = scale_expanded[..., :out_features, :in_features]

    # Multiply in float32, then round once to bfloat16.
    return (fp8_weight.float() * scale_expanded).to(torch.bfloat16)


def _copy_auxiliary_files(model_dir: str, out_dir: str) -> None:
    """Copy every regular, non-safetensors file from model_dir to out_dir."""
    for name in os.listdir(model_dir):
        src = os.path.join(model_dir, name)
        if not name.endswith(".safetensors") and os.path.isfile(src):
            shutil.copy2(src, os.path.join(out_dir, name))
            print(f"Copied {name}")


def _scan_scale_pairs(safetensor_files: list[str]) -> dict[str, str]:
    """Map each candidate weight key to its ``.scale`` key across all shards.

    Pattern: layers.X.attn.Y.scale -> layers.X.attn.Y.weight
    Only key names are inspected here; dtypes are checked when converting.
    """
    scales_map = {}
    for path in safetensor_files:
        with safe_open(path, framework="pt") as sf:
            for key in sf.keys():
                if key.endswith(_SCALE_SUFFIX):
                    weight_key = key[: -len(_SCALE_SUFFIX)] + _WEIGHT_SUFFIX
                    scales_map[weight_key] = key
    return scales_map


def _convert_shard(path: str, scales_map: dict[str, str], scale_keys: set[str]):
    """Load one shard and dequantize every FP8 weight whose scale is in it.

    Returns:
        (tensors, metadata, dequantized, left_fp8, scales_kept) where
        ``tensors`` is the dict to write out, ``metadata`` is the source
        shard's header metadata (may be None), ``dequantized`` counts
        converted weights (one scale is dropped per converted weight),
        ``left_fp8`` counts FP8 weights whose scale was not in this shard,
        and ``scales_kept`` counts scale tensors that were not consumed and
        were therefore written back unchanged.
    """
    tensors = {}
    scales_in_shard = {}
    fp8_weights_in_shard = {}

    with safe_open(path, framework="pt") as sf:
        # Carry the header metadata over to the rewritten shard.
        metadata = sf.metadata()
        for key in sf.keys():
            t = sf.get_tensor(key)
            if key in scale_keys:
                # Scale tensor - hold it until we know whether it is consumed.
                scales_in_shard[key] = t
            elif key in scales_map and t.dtype == FP8_WEIGHT_DTYPE:
                # FP8 weight that has a corresponding scale.
                fp8_weights_in_shard[key] = t
            else:
                # Regular tensor (BF16, FP32, or FP8 without scale) - keep.
                tensors[key] = t

    dequantized = 0
    left_fp8 = 0
    for weight_key, fp8_w in fp8_weights_in_shard.items():
        scale_key = scales_map[weight_key]
        scale = scales_in_shard.pop(scale_key, None)
        if scale is None:
            # Scale not in this shard (shouldn't happen but handle gracefully).
            print(f"WARNING: scale {scale_key} not found in same shard as {weight_key}")
            tensors[weight_key] = fp8_w  # keep as-is
            left_fp8 += 1
            continue
        tensors[weight_key] = dequantize_weight(fp8_w, scale)
        dequantized += 1

    # Whatever is left was not used to dequantize anything in this shard.
    # Dropping it would lose data (its weight is either not FP8 or lives in
    # another shard and was kept quantized), so write it back unchanged.
    scales_kept = len(scales_in_shard)
    tensors.update(scales_in_shard)

    return tensors, metadata, dequantized, left_fp8, scales_kept


def _update_config(out_dir: str) -> None:
    """Rewrite the copied config.json to describe a plain BF16 checkpoint."""
    cfg_path = os.path.join(out_dir, "config.json")
    if not os.path.exists(cfg_path):
        return

    with open(cfg_path, encoding="utf-8") as fh:
        cfg = json.load(fh)

    cfg["torch_dtype"] = "bfloat16"
    cfg["_experts_implementation"] = "eager"
    cfg.pop("quantization_config", None)

    with open(cfg_path, "w", encoding="utf-8") as fh:
        json.dump(cfg, fh, indent=2)
    print("Updated config.json: torch_dtype=bfloat16, _experts_implementation=eager")


def dequantize_model(model_dir: str, out_dir: str):
    """Convert every safetensors shard in model_dir and write it to out_dir.

    Non-safetensors files are copied verbatim, each shard is rewritten under
    the same file name with its paired FP8 weights dequantized to BF16, and
    the copied config.json (if any) is updated afterwards.

    Args:
        model_dir: directory holding the mixed-precision checkpoint.
        out_dir: directory to create/populate with the converted checkpoint.

    Raises:
        ValueError: if out_dir resolves to the same directory as model_dir.
    """
    # Shards are written under their original names, so writing into the
    # source directory would overwrite files that are still being read.
    if os.path.realpath(model_dir) == os.path.realpath(out_dir):
        raise ValueError("out_dir must be different from model_dir")

    os.makedirs(out_dir, exist_ok=True)

    # Copy non-safetensor files
    _copy_auxiliary_files(model_dir, out_dir)

    safetensor_files = sorted(glob.glob(os.path.join(model_dir, "*.safetensors")))
    total = len(safetensor_files)

    # First pass: build map of weight_key -> scale_key
    print("Scanning for FP8 weight+scale pairs...")
    scales_map = _scan_scale_pairs(safetensor_files)
    print(f"Found {len(scales_map)} FP8 weight+scale pairs")

    # Second pass: dequantize and save
    fp8_dequantized = 0
    fp8_scales_removed = 0
    fp8_left_quantized = 0
    scales_kept = 0
    scale_keys_global = set(scales_map.values())

    for index, path in enumerate(safetensor_files, start=1):
        tensors, metadata, converted, left_fp8, kept = _convert_shard(
            path, scales_map, scale_keys_global
        )
        fp8_dequantized += converted
        fp8_scales_removed += converted  # one scale is consumed per weight
        fp8_left_quantized += left_fp8
        scales_kept += kept

        out_path = os.path.join(out_dir, os.path.basename(path))
        save_file(tensors, out_path, metadata=metadata)
        # Release this shard before the next one is loaded to cap peak memory.
        del tensors

        if index % 10 == 0 or index == total:
            print(f"[{index}/{total}] {os.path.basename(path)} (dequantized {fp8_dequantized} FP8, removed {fp8_scales_removed} scales)")

    # Update config
    _update_config(out_dir)

    print(f"\nDone! Dequantized {fp8_dequantized} FP8 weights, removed {fp8_scales_removed} scale tensors")
    if fp8_left_quantized:
        print(
            f"WARNING: {fp8_left_quantized} FP8 weights were left quantized "
            "because their scale was not in the same shard; the output is "
            "not fully BF16"
        )
    if scales_kept:
        print(f"Kept {scales_kept} scale tensors that were not consumed by dequantization")


def main():
    """Parse command-line arguments and run the conversion."""
    parser = argparse.ArgumentParser(description="Dequantize FP8 weights to BF16")
    parser.add_argument("model_dir", help="Path to mixed-precision model")
    parser.add_argument("out_dir", help="Path to write dequantized BF16 model")
    args = parser.parse_args()
    dequantize_model(args.model_dir, args.out_dir)


if __name__ == "__main__":
    main()