Files
deepseek-v4-quant/scripts/dequant_fp8_to_bf16.py
biondizzle db6beb5b76 Complete dequant script: handles INT4 experts, FP8 attention, FP8 shared experts
INT4 expert weights are packed 2-per-byte into int8 with float8_e8m0fnu
per-row 32-column block scales. Unpacking: lower nibble first, upper second.
Output dimensions are 2x the stored dimensions (e.g. [3072,3584] → [3072,7168]).

Also adds progress output with ETA per shard so screen sessions stay alive.
2026-05-08 01:39:50 +00:00

262 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Complete dequantization of DeepSeek V4 Pro mixed-precision to pure BF16.
Handles ALL compressed tensor types found in the mixed-precision model:
1. FP8 attention weights (float8_e4m3fn + float8_e8m0fnu block scales)
- weight × scale_expanded → BF16
- 128×128 block quantization
2. INT4 expert weights (int8 packed + float8_e8m0fnu block scales)
- Unpack 2 int4 values per int8 byte (lower nibble first, upper second)
- Dequantize: int4_signed × scale_expanded → BF16
- Per-row, 32-column block scaling
- Output dimensions are 2× the stored dimensions
3. FP8 shared expert weights (float8_e4m3fn + float8_e8m0fnu block scales)
- Same as FP8 attention dequantization
After dequantization, all weights are pure BF16. FP8Linear.forward() sees
element_size() > 1 and falls back to F.linear(), avoiding broken FP8 kernels
on Blackwell GPUs. The model can then be loaded by modelopt without shape
mismatches.
"""
import os, glob, json, shutil, sys, time
from safetensors import safe_open
from safetensors.torch import save_file
import torch
FP8_WEIGHT_DTYPE = torch.float8_e4m3fn
FP8_SCALE_DTYPE = torch.float8_e8m0fnu
BLOCK_SIZE_FP8 = (128, 128)
INT4_BLOCK_SIZE = 32 # columns per scale value for INT4 expert weights
def dequantize_fp8_weight(
    fp8_weight: torch.Tensor,
    scale: torch.Tensor,
    block_size: "tuple[int, int] | None" = None,
) -> torch.Tensor:
    """Dequantize a block-wise quantized FP8 weight to BF16.

    Every ``block_size`` tile of ``fp8_weight`` is multiplied by its own scale
    value. The last block row/column may be partial when the weight shape is
    not a multiple of the block size; surplus scale entries are ignored.

    Args:
        fp8_weight: (out_features, in_features) weight, stored as
            float8_e4m3fn in the checkpoint (anything ``.float()`` accepts works).
        scale: (ceil(out_features / block_rows), ceil(in_features / block_cols))
            per-block scales, stored as float8_e8m0fnu in the checkpoint.
        block_size: (block_rows, block_cols); defaults to ``BLOCK_SIZE_FP8``.

    Returns:
        (out_features, in_features) bfloat16 tensor.

    Raises:
        ValueError: if the inputs are not 2-D, the block size is not positive,
            or ``scale`` has too few blocks to cover ``fp8_weight``.
    """
    if block_size is None:
        block_size = BLOCK_SIZE_FP8
    block_rows, block_cols = block_size
    if block_rows <= 0 or block_cols <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    if fp8_weight.dim() != 2 or scale.dim() != 2:
        raise ValueError(
            f"expected 2-D weight and scale, got {tuple(fp8_weight.shape)} "
            f"and {tuple(scale.shape)}"
        )

    out_features, in_features = fp8_weight.shape
    # Ceil-divide: a partial trailing tile still needs its own scale entry.
    needed_rows = -(-out_features // block_rows)
    needed_cols = -(-in_features // block_cols)
    if scale.shape[0] < needed_rows or scale.shape[1] < needed_cols:
        raise ValueError(
            f"scale {tuple(scale.shape)} too small for weight "
            f"{tuple(fp8_weight.shape)} with block size {tuple(block_size)}"
        )

    # Broadcast each scale over its tile, then crop the overhang of partial tiles.
    scale_expanded = (
        scale.float()
        .repeat_interleave(block_rows, dim=0)
        .repeat_interleave(block_cols, dim=1)[:out_features, :in_features]
    )
    return (fp8_weight.float() * scale_expanded).to(torch.bfloat16)
def dequantize_int4_weight(
    int8_packed: torch.Tensor,
    scale: torch.Tensor,
    block_size: "int | None" = None,
) -> torch.Tensor:
    """Dequantize an INT4-packed expert weight to BF16.

    Each int8 byte holds two signed 4-bit values: the lower nibble (bits 0-3)
    becomes output column ``2k`` and the upper nibble (bits 4-7) column
    ``2k + 1``. Nibbles are read as two's complement, giving the range -8..7.
    Every ``block_size`` consecutive unpacked columns share one scale value.

    Args:
        int8_packed: (..., in_features // 2) int8 tensor; leading dimensions
            (e.g. a stacked expert axis) are preserved.
        scale: (..., ceil(in_features / block_size)) per-row block scales,
            stored as float8_e8m0fnu in the checkpoint. The leading dimensions
            must broadcast against ``int8_packed``'s.
        block_size: unpacked columns per scale value; defaults to
            ``INT4_BLOCK_SIZE``.

    Returns:
        (..., in_features) bfloat16 tensor on the same device as the input.
        The last dimension is twice the packed one.

    Raises:
        ValueError: if an input is 0-d, ``block_size`` is not positive, or
            ``scale`` has too few column blocks.
    """
    if block_size is None:
        block_size = INT4_BLOCK_SIZE
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    if int8_packed.dim() < 1 or scale.dim() < 1:
        raise ValueError("int8_packed and scale must have at least one dimension")

    in_features_full = int8_packed.shape[-1] * 2
    needed_blocks = -(-in_features_full // block_size)  # ceil division
    if scale.shape[-1] < needed_blocks:
        raise ValueError(
            f"scale has {scale.shape[-1]} column blocks, need {needed_blocks} "
            f"for {in_features_full} unpacked columns (block_size={block_size})"
        )

    # Masking after the arithmetic shift discards int8 sign extension.
    lower = int8_packed & 0x0F
    upper = (int8_packed >> 4) & 0x0F
    # Two's-complement nibble -> signed: 0..7 stay, 8..15 map to -8..-1.
    lower = torch.where(lower > 7, lower - 16, lower)
    upper = torch.where(upper > 7, upper - 16, upper)

    # Stacking on a new last axis and flattening yields l0, u0, l1, u1, ...
    # and keeps the result on the input's device (no CPU-only scratch buffer).
    unpacked = torch.stack((lower, upper), dim=-1).reshape(
        *int8_packed.shape[:-1], in_features_full
    )

    scale_expanded = scale.float().repeat_interleave(block_size, dim=-1)
    scale_expanded = scale_expanded[..., :in_features_full]
    return (unpacked.float() * scale_expanded).to(torch.bfloat16)
def _index_shard_keys(safetensor_files):
    """Return ``{tensor_key: shard_path}`` for every tensor in *safetensor_files* (headers only)."""
    key_to_file = {}
    for path in safetensor_files:
        with safe_open(path, framework="pt") as sf:
            for key in sf.keys():
                key_to_file[key] = path
    return key_to_file


def _pair_scales_with_weights(key_to_file):
    """Return ``{weight_key: scale_key}`` for each ``X.scale`` whose ``X.weight`` also exists."""
    weight_to_scale = {}
    for key in key_to_file:
        if key.endswith(".scale"):
            weight_key = key[: -len(".scale")] + ".weight"
            if weight_key in key_to_file:
                weight_to_scale[weight_key] = key
    return weight_to_scale


def _load_tensor(path, key):
    """Read the single tensor *key* from the safetensors file at *path*."""
    with safe_open(path, framework="pt") as sf:
        return sf.get_tensor(key)


def _convert_shard(path, out_path, weight_to_scale, scale_keys, key_to_file, stats):
    """Dequantize one shard, save it to *out_path*, and return ``{key: nbytes}`` of what was written.

    Paired scales (``scale_keys``) are folded into their weights and never
    written. A scale stored in a different shard than its weight is read from
    that shard. *stats* is updated in place.
    """
    tensors = {}
    local_scales = {}
    with safe_open(path, framework="pt") as sf:
        metadata = dict(sf.metadata() or {})
        keys = list(sf.keys())
        for key in keys:
            if key in scale_keys:
                local_scales[key] = sf.get_tensor(key)
        for key in keys:
            if key in scale_keys:
                continue  # folded into its weight below (possibly in another shard)
            t = sf.get_tensor(key)
            scale_key = weight_to_scale.get(key)
            if scale_key is None:
                # Regular tensor (BF16, FP32, int64, etc.) - keep as-is
                tensors[key] = t
                stats["unchanged"] += 1
                continue
            if t.dtype not in (torch.int8, FP8_WEIGHT_DTYPE):
                print(f"  WARNING: {key} has scale {scale_key} but dtype {t.dtype}; "
                      f"kept as-is and its scale is not written", flush=True)
                tensors[key] = t
                stats["unchanged"] += 1
                continue
            scale = local_scales.get(scale_key)
            if scale is None:
                # Cross-shard pair: without this the scale would be dropped
                # while the weight stayed compressed, losing data.
                scale = _load_tensor(key_to_file[scale_key], scale_key)
            if t.dtype == torch.int8:
                tensors[key] = dequantize_int4_weight(t, scale)  # INT4 packed expert weight
                stats["int4_dequantized"] += 1
            else:
                tensors[key] = dequantize_fp8_weight(t, scale)  # FP8 attention / shared expert
                stats["fp8_dequantized"] += 1
    # Every paired scale stored in this shard is omitted from the output exactly once.
    stats["scales_removed"] += len(local_scales)
    # Preserve the source shard metadata; some loaders check the "format" key.
    metadata.setdefault("format", "pt")
    save_file(tensors, out_path, metadata=metadata)
    return {key: t.numel() * t.element_size() for key, t in tensors.items()}


def _rewrite_index(out_dir, written):
    """Rewrite every copied ``*.safetensors.index.json`` in *out_dir* to match the output shards.

    *written* maps tensor key -> (shard file name, size in bytes). Removed scale
    keys disappear from ``weight_map`` and ``metadata.total_size`` is recomputed.
    """
    for index_path in glob.glob(os.path.join(out_dir, "*.safetensors.index.json")):
        with open(index_path) as fh:
            index = json.load(fh)
        index["weight_map"] = {key: shard for key, (shard, _) in sorted(written.items())}
        metadata = index.get("metadata")
        if not isinstance(metadata, dict):
            metadata = {}
        metadata["total_size"] = sum(size for _, size in written.values())
        index["metadata"] = metadata
        with open(index_path, "w") as fh:
            json.dump(index, fh, indent=2)
        print(f"Updated {os.path.basename(index_path)}: {len(written)} tensors")


def _update_config(out_dir):
    """Mark the copied config.json as BF16, force eager experts, and drop quantization_config."""
    cfg_path = os.path.join(out_dir, "config.json")
    if not os.path.exists(cfg_path):
        return
    with open(cfg_path) as fh:
        cfg = json.load(fh)
    cfg["torch_dtype"] = "bfloat16"
    cfg["_experts_implementation"] = "eager"
    cfg.pop("quantization_config", None)
    with open(cfg_path, "w") as fh:
        json.dump(cfg, fh, indent=2)
    print("\nUpdated config.json: torch_dtype=bfloat16, _experts_implementation=eager")


def _verify_output(out_dir, max_shards=5, max_reports=5):
    """Spot-check the first *max_shards* output shards for FP8/INT8 tensors and print the output size."""
    print("\nVerifying...")
    compressed = (FP8_SCALE_DTYPE, FP8_WEIGHT_DTYPE, torch.int8)
    remaining_compressed = 0
    for f in sorted(glob.glob(os.path.join(out_dir, "*.safetensors")))[:max_shards]:
        with safe_open(f, framework="pt") as sf:
            for key in sf.keys():
                t = sf.get_tensor(key)
                if t.dtype in compressed:
                    remaining_compressed += 1
                    if remaining_compressed <= max_reports:
                        print(f" REMAINING: {key} {t.dtype} {t.shape}")
    if remaining_compressed == 0:
        print(" ✅ No compressed tensors remaining - model is pure BF16!")
    else:
        print(f" ⚠️ {remaining_compressed} compressed tensors still present")
    out_size = sum(
        os.path.getsize(os.path.join(out_dir, f))
        for f in os.listdir(out_dir)
        if f.endswith(".safetensors")
    )
    print(f"Output size: {out_size / 1e12:.2f} TB")


def dequantize_model(model_dir: str, out_dir: str):
    """Dequantize every shard in *model_dir* to BF16 and write a loadable copy to *out_dir*.

    Steps:
      1. Copy every non-``.safetensors`` file (e.g. config.json, any shard index).
      2. Scan all shard headers and pair each ``X.scale`` with its ``X.weight``,
         even when the two live in different shards.
      3. Per shard: dequantize int8-packed INT4 and float8_e4m3fn weights with
         their scales, omit the consumed scales, keep everything else as-is.
      4. Rewrite the shard index and config.json to describe the BF16 output.
      5. Spot-check the first output shards for leftover compressed dtypes.

    Args:
        model_dir: directory holding the mixed-precision ``*.safetensors`` shards.
        out_dir: destination directory (created if missing).

    Raises:
        ValueError: if *out_dir* resolves to *model_dir* (shards would be
            overwritten while still being read).
    """
    if os.path.realpath(model_dir) == os.path.realpath(out_dir):
        raise ValueError("out_dir must differ from model_dir")
    os.makedirs(out_dir, exist_ok=True)

    # Copy non-safetensor files
    print("Copying metadata files...")
    for name in os.listdir(model_dir):
        src = os.path.join(model_dir, name)
        if not name.endswith(".safetensors") and os.path.isfile(src):
            shutil.copy2(src, os.path.join(out_dir, name))
            print(f" Copied {name}")

    safetensor_files = sorted(glob.glob(os.path.join(model_dir, "*.safetensors")))
    total_shards = len(safetensor_files)
    print(f"Found {total_shards} shards")

    # Header-only scan: no tensor data is read here.
    print("\nScanning for weight+scale pairs...")
    key_to_file = _index_shard_keys(safetensor_files)
    weight_to_scale = _pair_scales_with_weights(key_to_file)
    scale_keys = set(weight_to_scale.values())
    cross_shard = sum(1 for w, s in weight_to_scale.items() if key_to_file[w] != key_to_file[s])
    orphan_scales = sum(1 for k in key_to_file if k.endswith(".scale") and k not in scale_keys)
    print(f"Found {len(weight_to_scale)} weight+scale pairs ({cross_shard} split across shards)")
    if orphan_scales:
        print(f"  WARNING: {orphan_scales} .scale tensors have no matching .weight; kept unchanged")

    stats = {"int4_dequantized": 0, "fp8_dequantized": 0, "scales_removed": 0, "unchanged": 0}
    written = {}  # tensor key -> (output shard name, size in bytes)
    start_time = time.time()
    for i, path in enumerate(safetensor_files):
        shard_start = time.time()
        shard_name = os.path.basename(path)
        sizes = _convert_shard(
            path, os.path.join(out_dir, shard_name),
            weight_to_scale, scale_keys, key_to_file, stats,
        )
        written.update({key: (shard_name, nbytes) for key, nbytes in sizes.items()})

        shard_time = time.time() - shard_start
        elapsed = time.time() - start_time
        rate = (i + 1) / elapsed if elapsed > 0 else 0
        eta = (total_shards - i - 1) / rate if rate > 0 else 0
        # Flush so progress is visible even when stdout is redirected to a file.
        print(f"[{i+1}/{total_shards}] {shard_name} "
              f"({stats['int4_dequantized']} int4, {stats['fp8_dequantized']} fp8, "
              f"{stats['scales_removed']} scales removed) "
              f"[{shard_time:.1f}s, ETA: {eta/60:.0f}min]", flush=True)

    _rewrite_index(out_dir, written)
    _update_config(out_dir)

    total_time = time.time() - start_time
    print(f"\nDone in {total_time/60:.1f} minutes!")
    print(f" INT4 expert weights dequantized: {stats['int4_dequantized']}")
    print(f" FP8 weights dequantized: {stats['fp8_dequantized']}")
    print(f" Scale tensors removed: {stats['scales_removed']}")
    print(f" Unchanged tensors: {stats['unchanged']}")

    _verify_output(out_dir)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Complete dequantization of DeepSeek V4 Pro to BF16")
parser.add_argument("model_dir", help="Path to mixed-precision model")
parser.add_argument("out_dir", help="Path to write dequantized BF16 model")
args = parser.parse_args()
dequantize_model(args.model_dir, args.out_dir)