#!/usr/bin/env python3 """ Complete dequantization of DeepSeek V4 Pro mixed-precision to pure BF16. Handles ALL compressed tensor types found in the mixed-precision model: 1. FP8 attention weights (float8_e4m3fn + float8_e8m0fnu block scales) - weight × scale_expanded → BF16 - 128×128 block quantization 2. FP4 (E2M1) expert weights (int8 packed + float8_e8m0fnu block scales) - Unpack 2 FP4 values per int8 byte (lower nibble first, upper second) - Dequantize via E2M1 LUT lookup × scale_expanded → BF16 - Per-row, 32-column block scaling (MXFP4 microscaling format) - Output dimensions are 2× the stored dimensions - Verified: nibble index 0 vs 8 ratio = 0.996 (FP4 -0.0 vs +0.0), NOT INT4 where index 8 = -8 would be rare 3. FP8 shared expert weights (float8_e4m3fn + float8_e8m0fnu block scales) - Same as FP8 attention dequantization After dequantization, all weights are pure BF16. FP8Linear.forward() sees element_size() > 1 and falls back to F.linear(), avoiding broken FP8 kernels on Blackwell GPUs. The model can then be loaded by modelopt without shape mismatches. """ import os, glob, json, shutil, sys, time from safetensors import safe_open from safetensors.torch import save_file import torch FP8_WEIGHT_DTYPE = torch.float8_e4m3fn FP8_SCALE_DTYPE = torch.float8_e8m0fnu BLOCK_SIZE_FP8 = (128, 128) FP4_BLOCK_SIZE = 32 # columns per scale value for MXFP4 expert weights # E2M1 FP4 lookup table (MXFP4 microscaling format) # Index 0-7: positive values (sign=0, 2-bit exp, 1-bit mantissa) # Index 8-15: negative values (sign=1) # Mapping: 0→0, 1→0.5, 2→1, 3→1.5, 4→2, 5→3, 6→4, 7→6 FP4_E2M1_LUT = torch.tensor([ 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, ], dtype=torch.float32) def dequantize_fp8_weight(fp8_weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: """Dequantize block-wise FP8 weight to BF16. fp8_weight: (out_features, in_features) float8_e4m3fn scale: (out_features//128, in_features//128) float8_e8m0fnu """ scale_f32 = scale.float() out_features, in_features = fp8_weight.shape scale_expanded = scale_f32.repeat_interleave(BLOCK_SIZE_FP8[0], dim=0).repeat_interleave(BLOCK_SIZE_FP8[1], dim=1) scale_expanded = scale_expanded[:out_features, :in_features] weight_bf16 = fp8_weight.float() * scale_expanded return weight_bf16.to(torch.bfloat16) def dequantize_fp4_weight(int8_packed: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: """Dequantize MXFP4 (E2M1) expert weight to BF16. FP4 values are packed 2-per-byte into int8 tensors. Lower nibble (bits 0-3) is the first value, upper nibble (bits 4-7) is the second. E2M1 format: 1 sign + 2 exponent + 1 mantissa bit. Scale is per-row with 32-column blocks (float8_e8m0fnu, MX microscaling). Output dimensions are 2× the stored dimensions. int8_packed: (out_features, in_features//2) int8 scale: (out_features, in_features//32) float8_e8m0fnu returns: (out_features, in_features) bfloat16 """ lut = FP4_E2M1_LUT.to(int8_packed.device) # Unpack nibble indices lower_idx = (int8_packed & 0x0F).long() # 0-15 upper_idx = ((int8_packed >> 4) & 0x0F).long() # 0-15 # LUT lookup lower = lut[lower_idx] # float32 upper = lut[upper_idx] # float32 out_features = int8_packed.shape[0] in_features_full = int8_packed.shape[1] * 2 # 2× expansion # Expand scale: (out_features, in_features//32) → (out_features, in_features) scale_f32 = scale.float() scale_expanded = scale_f32.repeat_interleave(FP4_BLOCK_SIZE, dim=1) scale_expanded = scale_expanded[:, :in_features_full] # Interleave lower and upper nibbles unpacked = torch.empty(out_features, in_features_full, dtype=torch.float32, device=int8_packed.device) unpacked[:, 0::2] = lower unpacked[:, 1::2] = upper # Dequantize: FP4 value × E8M0 scale bf16_weight = (unpacked * scale_expanded).to(torch.bfloat16) return bf16_weight def dequantize_model(model_dir: str, out_dir: str): os.makedirs(out_dir, exist_ok=True) # Copy non-safetensor files print("Copying metadata files...") for f in os.listdir(model_dir): fp = os.path.join(model_dir, f) if not f.endswith(".safetensors") and os.path.isfile(fp): shutil.copy2(fp, os.path.join(out_dir, f)) print(f" Copied {f}") safetensor_files = sorted(glob.glob(os.path.join(model_dir, "*.safetensors"))) total_shards = len(safetensor_files) print(f"Found {total_shards} shards") # First pass: build scale-key → weight-key mapping print("\nScanning for weight+scale pairs...") scale_to_weight = {} for f in safetensor_files: with safe_open(f, framework="pt") as sf: for key in sf.keys(): if key.endswith(".scale"): weight_key = key[:-len(".scale")] + ".weight" scale_to_weight[key] = weight_key weight_to_scale = {v: k for k, v in scale_to_weight.items()} print(f"Found {len(scale_to_weight)} weight+scale pairs") # Classify weights by type (sample first 2 shards) fp4_weight_keys = set() fp8_weight_keys = set() scale_keys = set(scale_to_weight.keys()) for f in safetensor_files[:2]: with safe_open(f, framework="pt") as sf: for key in sf.keys(): if key in weight_to_scale: t = sf.get_tensor(key) if t.dtype == torch.int8: fp4_weight_keys.add(key) elif t.dtype == FP8_WEIGHT_DTYPE: fp8_weight_keys.add(key) print(f" FP4 (E2M1) expert weights (packed): ~{len(fp4_weight_keys)} per shard") print(f" FP8 attention/shared-expert weights: ~{len(fp8_weight_keys)} per shard") # Second pass: dequantize and save stats = {"fp4_dequantized": 0, "fp8_dequantized": 0, "scales_removed": 0, "unchanged": 0} start_time = time.time() for i, f in enumerate(safetensor_files): shard_start = time.time() tensors = {} scales_in_shard = {} with safe_open(f, framework="pt") as sf: keys = list(sf.keys()) # First: collect scales for key in keys: if key in scale_keys: t = sf.get_tensor(key) scales_in_shard[key] = t # Second: process weights and other tensors for key in keys: if key in scale_keys: continue # handled separately t = sf.get_tensor(key) if key in weight_to_scale and t.dtype == torch.int8: # FP4 (E2M1) packed expert weight (MXFP4 microscaling) scale_key = weight_to_scale[key] scale = scales_in_shard.get(scale_key) if scale is None: print(f" WARNING: scale {scale_key} not in same shard as {key}") tensors[key] = t # keep as-is continue bf16 = dequantize_fp4_weight(t, scale) tensors[key] = bf16 stats["fp4_dequantized"] += 1 del scales_in_shard[scale_key] stats["scales_removed"] += 1 elif key in weight_to_scale and t.dtype == FP8_WEIGHT_DTYPE: # FP8 weight (attention or shared expert) scale_key = weight_to_scale[key] scale = scales_in_shard.get(scale_key) if scale is None: print(f" WARNING: scale {scale_key} not in same shard as {key}") tensors[key] = t continue bf16 = dequantize_fp8_weight(t, scale) tensors[key] = bf16 stats["fp8_dequantized"] += 1 del scales_in_shard[scale_key] stats["scales_removed"] += 1 else: # Regular tensor (BF16, FP32, int64, etc.) - keep as-is tensors[key] = t stats["unchanged"] += 1 # Remove unused scales for sk in scales_in_shard: stats["scales_removed"] += 1 out_path = os.path.join(out_dir, os.path.basename(f)) if os.path.exists(out_path) and os.path.getsize(out_path) > 0: # Resume: skip already-dequantized shards print(f"[{i+1}/{total_shards}] Skipping (already done): {os.path.basename(f)}") del tensors, scales_in_shard continue save_file(tensors, out_path) shard_time = time.time() - shard_start elapsed = time.time() - start_time rate = (i + 1) / elapsed if elapsed > 0 else 0 eta = (total_shards - i - 1) / rate if rate > 0 else 0 print(f"[{i+1}/{total_shards}] {os.path.basename(f)} " f"({stats['fp4_dequantized']} fp4, {stats['fp8_dequantized']} fp8, " f"{stats['scales_removed']} scales rm) " f"[{shard_time:.1f}s, ETA: {eta/60:.0f}min]") del tensors, scales_in_shard # Update config cfg_path = os.path.join(out_dir, "config.json") if os.path.exists(cfg_path): cfg = json.load(open(cfg_path)) cfg["torch_dtype"] = "bfloat16" cfg["_experts_implementation"] = "eager" if "quantization_config" in cfg: del cfg["quantization_config"] json.dump(cfg, open(cfg_path, "w"), indent=2) print(f"\nUpdated config.json: torch_dtype=bfloat16, _experts_implementation=eager") total_time = time.time() - start_time print(f"\nDone in {total_time/60:.1f} minutes!") print(f" FP4 expert weights dequantized: {stats['fp4_dequantized']}") print(f" FP8 weights dequantized: {stats['fp8_dequantized']}") print(f" Scale tensors removed: {stats['scales_removed']}") print(f" Unchanged tensors: {stats['unchanged']}") # Verify no FP8/INT8 remaining print("\nVerifying...") remaining_compressed = 0 for f in sorted(glob.glob(os.path.join(out_dir, "*.safetensors")))[:5]: with safe_open(f, framework="pt") as sf: for key in sf.keys(): t = sf.get_tensor(key) if t.dtype in (torch.float8_e8m0fnu, torch.float8_e4m3fn, torch.int8): remaining_compressed += 1 if remaining_compressed <= 5: print(f" REMAINING: {key} {t.dtype} {t.shape}") if remaining_compressed == 0: print(" ✅ No compressed tensors remaining — model is pure BF16!") else: print(f" ⚠️ {remaining_compressed} compressed tensors still present") out_size = sum(os.path.getsize(os.path.join(out_dir, f)) for f in os.listdir(out_dir) if f.endswith(".safetensors")) print(f"Output size: {out_size / 1e12:.2f} TB") if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="Complete dequantization of DeepSeek V4 Pro to BF16") parser.add_argument("model_dir", help="Path to mixed-precision model") parser.add_argument("out_dir", help="Path to write dequantized BF16 model") args = parser.parse_args() dequantize_model(args.model_dir, args.out_dir)