#!/usr/bin/env python3 """ Complete dequantization of DeepSeek V4 Pro mixed-precision to pure BF16. Handles ALL compressed tensor types found in the mixed-precision model: 1. FP8 attention weights (float8_e4m3fn + float8_e8m0fnu block scales) - weight × scale_expanded → BF16 - 128×128 block quantization 2. INT4 expert weights (int8 packed + float8_e8m0fnu block scales) - Unpack 2 int4 values per int8 byte (lower nibble first, upper second) - Dequantize: int4_signed × scale_expanded → BF16 - Per-row, 32-column block scaling - Output dimensions are 2× the stored dimensions 3. FP8 shared expert weights (float8_e4m3fn + float8_e8m0fnu block scales) - Same as FP8 attention dequantization After dequantization, all weights are pure BF16. FP8Linear.forward() sees element_size() > 1 and falls back to F.linear(), avoiding broken FP8 kernels on Blackwell GPUs. The model can then be loaded by modelopt without shape mismatches. """ import os, glob, json, shutil, sys, time from safetensors import safe_open from safetensors.torch import save_file import torch FP8_WEIGHT_DTYPE = torch.float8_e4m3fn FP8_SCALE_DTYPE = torch.float8_e8m0fnu BLOCK_SIZE_FP8 = (128, 128) INT4_BLOCK_SIZE = 32 # columns per scale value for INT4 expert weights def dequantize_fp8_weight(fp8_weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: """Dequantize block-wise FP8 weight to BF16. fp8_weight: (out_features, in_features) float8_e4m3fn scale: (out_features//128, in_features//128) float8_e8m0fnu """ scale_f32 = scale.float() out_features, in_features = fp8_weight.shape scale_expanded = scale_f32.repeat_interleave(BLOCK_SIZE_FP8[0], dim=0).repeat_interleave(BLOCK_SIZE_FP8[1], dim=1) scale_expanded = scale_expanded[:out_features, :in_features] weight_bf16 = fp8_weight.float() * scale_expanded return weight_bf16.to(torch.bfloat16) def dequantize_int4_weight(int8_packed: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: """Dequantize INT4-packed expert weight to BF16. INT4 values are packed 2-per-byte into int8 tensors. Lower nibble (bits 0-3) is the first value, upper nibble (bits 4-7) is the second. Signed int4 range: -8 to 7. Scale is per-row with 32-column blocks (float8_e8m0fnu). Output dimensions are 2× the stored dimensions. int8_packed: (out_features, in_features//2) int8 scale: (out_features, in_features//32) float8_e8m0fnu returns: (out_features, in_features) bfloat16 """ # Unpack int4 from int8 lower = (int8_packed & 0x0F).to(torch.int8) # 0-15 upper = ((int8_packed >> 4) & 0x0F).to(torch.int8) # 0-15 # Convert unsigned to signed int4: 0-7 stay, 8-15 → -8 to -1 lower_signed = torch.where(lower > 7, lower - 16, lower).float() upper_signed = torch.where(upper > 7, upper - 16, upper).float() out_features = int8_packed.shape[0] in_features_full = int8_packed.shape[1] * 2 # 2× expansion # Expand scale: (out_features, in_features//32) → (out_features, in_features) scale_f32 = scale.float() scale_expanded = scale_f32.repeat_interleave(INT4_BLOCK_SIZE, dim=1) scale_expanded = scale_expanded[:, :in_features_full] # Interleave lower and upper nibbles unpacked = torch.zeros(out_features, in_features_full, dtype=torch.float32) unpacked[:, 0::2] = lower_signed unpacked[:, 1::2] = upper_signed # Dequantize bf16_weight = (unpacked * scale_expanded).to(torch.bfloat16) return bf16_weight def dequantize_model(model_dir: str, out_dir: str): os.makedirs(out_dir, exist_ok=True) # Copy non-safetensor files print("Copying metadata files...") for f in os.listdir(model_dir): fp = os.path.join(model_dir, f) if not f.endswith(".safetensors") and os.path.isfile(fp): shutil.copy2(fp, os.path.join(out_dir, f)) print(f" Copied {f}") safetensor_files = sorted(glob.glob(os.path.join(model_dir, "*.safetensors"))) total_shards = len(safetensor_files) print(f"Found {total_shards} shards") # First pass: build scale-key → weight-key mapping # Pattern: *.scale → *.weight print("\nScanning for weight+scale pairs...") scale_to_weight = {} # scale_key → weight_key for f in safetensor_files: with safe_open(f, framework="pt") as sf: for key in sf.keys(): if key.endswith(".scale"): weight_key = key[:-len(".scale")] + ".weight" scale_to_weight[key] = weight_key # Also find weight → scale mapping weight_to_scale = {v: k for k, v in scale_to_weight.items()} print(f"Found {len(scale_to_weight)} weight+scale pairs") # Classify weights by type int4_weight_keys = set() fp8_weight_keys = set() scale_keys = set(scale_to_weight.keys()) for f in safetensor_files[:2]: # Sample to classify with safe_open(f, framework="pt") as sf: for key in sf.keys(): if key in weight_to_scale: t = sf.get_tensor(key) if t.dtype == torch.int8: int4_weight_keys.add(key) elif t.dtype == FP8_WEIGHT_DTYPE: fp8_weight_keys.add(key) print(f" INT4 expert weights (packed): ~{len(int4_weight_keys)} per shard") print(f" FP8 attention/shared-expert weights: ~{len(fp8_weight_keys)} per shard") # Second pass: dequantize and save stats = {"int4_dequantized": 0, "fp8_dequantized": 0, "scales_removed": 0, "unchanged": 0} start_time = time.time() for i, f in enumerate(safetensor_files): shard_start = time.time() tensors = {} scales_in_shard = {} weights_to_dequant = {} with safe_open(f, framework="pt") as sf: keys = list(sf.keys()) # First: collect scales for key in keys: if key in scale_keys: t = sf.get_tensor(key) scales_in_shard[key] = t # Second: process weights and other tensors for key in keys: if key in scale_keys: continue # handled separately t = sf.get_tensor(key) if key in weight_to_scale and t.dtype == torch.int8: # INT4 packed expert weight scale_key = weight_to_scale[key] scale = scales_in_shard.get(scale_key) if scale is None: print(f" WARNING: scale {scale_key} not in same shard as {key}") tensors[key] = t # keep as-is continue bf16 = dequantize_int4_weight(t, scale) tensors[key] = bf16 stats["int4_dequantized"] += 1 del scales_in_shard[scale_key] stats["scales_removed"] += 1 elif key in weight_to_scale and t.dtype == FP8_WEIGHT_DTYPE: # FP8 weight (attention or shared expert) scale_key = weight_to_scale[key] scale = scales_in_shard.get(scale_key) if scale is None: print(f" WARNING: scale {scale_key} not in same shard as {key}") tensors[key] = t continue bf16 = dequantize_fp8_weight(t, scale) tensors[key] = bf16 stats["fp8_dequantized"] += 1 del scales_in_shard[scale_key] stats["scales_removed"] += 1 else: # Regular tensor (BF16, FP32, int64, etc.) - keep as-is tensors[key] = t stats["unchanged"] += 1 # Remove unused scales for sk in scales_in_shard: stats["scales_removed"] += 1 out_path = os.path.join(out_dir, os.path.basename(f)) save_file(tensors, out_path) shard_time = time.time() - shard_start elapsed = time.time() - start_time rate = (i + 1) / elapsed if elapsed > 0 else 0 eta = (total_shards - i - 1) / rate if rate > 0 else 0 print(f"[{i+1}/{total_shards}] {os.path.basename(f)} " f"({stats['int4_dequantized']} int4, {stats['fp8_dequantized']} fp8, " f"{stats['scales_removed']} scales removed) " f"[{shard_time:.1f}s, ETA: {eta/60:.0f}min]") del tensors, scales_in_shard # Update config cfg_path = os.path.join(out_dir, "config.json") if os.path.exists(cfg_path): cfg = json.load(open(cfg_path)) cfg["torch_dtype"] = "bfloat16" cfg["_experts_implementation"] = "eager" if "quantization_config" in cfg: del cfg["quantization_config"] json.dump(cfg, open(cfg_path, "w"), indent=2) print(f"\nUpdated config.json: torch_dtype=bfloat16, _experts_implementation=eager") total_time = time.time() - start_time print(f"\nDone in {total_time/60:.1f} minutes!") print(f" INT4 expert weights dequantized: {stats['int4_dequantized']}") print(f" FP8 weights dequantized: {stats['fp8_dequantized']}") print(f" Scale tensors removed: {stats['scales_removed']}") print(f" Unchanged tensors: {stats['unchanged']}") # Verify no FP8/INT8 remaining print("\nVerifying...") remaining_compressed = 0 for f in sorted(glob.glob(os.path.join(out_dir, "*.safetensors")))[:5]: with safe_open(f, framework="pt") as sf: for key in sf.keys(): t = sf.get_tensor(key) if t.dtype in (torch.float8_e8m0fnu, torch.float8_e4m3fn, torch.int8): remaining_compressed += 1 if remaining_compressed <= 5: print(f" REMAINING: {key} {t.dtype} {t.shape}") if remaining_compressed == 0: print(" ✅ No compressed tensors remaining - model is pure BF16!") else: print(f" ⚠️ {remaining_compressed} compressed tensors still present") out_size = sum(os.path.getsize(os.path.join(out_dir, f)) for f in os.listdir(out_dir) if f.endswith(".safetensors")) print(f"Output size: {out_size / 1e12:.2f} TB") if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="Complete dequantization of DeepSeek V4 Pro to BF16") parser.add_argument("model_dir", help="Path to mixed-precision model") parser.add_argument("out_dir", help="Path to write dequantized BF16 model") args = parser.parse_args() dequantize_model(args.model_dir, args.out_dir)