Files
deepseek-v4-quant/scripts/dequant_fp8_to_bf16.py
biondizzle f8533197f2 Fix: expert weights are FP4 (E2M1), not INT4 - verified with nibble analysis
Nibble index 0 vs 8 ratio = 0.996 (FP4 -0.0 ≈ +0.0), NOT INT4 where -8 would be rare.
FP4 dequant uses E2M1 LUT lookup × E8M0 scale (MXFP4 microscaling).
Also adds model_opt_nvfp4_full.py for full model NVFP4 quantization.
2026-05-08 02:25:43 +00:00

272 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Complete dequantization of DeepSeek V4 Pro mixed-precision to pure BF16.
Handles ALL compressed tensor types found in the mixed-precision model:
1. FP8 attention weights (float8_e4m3fn + float8_e8m0fnu block scales)
- weight × scale_expanded → BF16
- 128×128 block quantization
2. FP4 (E2M1) expert weights (int8 packed + float8_e8m0fnu block scales)
- Unpack 2 FP4 values per int8 byte (lower nibble first, upper second)
- Dequantize via E2M1 LUT lookup × scale_expanded → BF16
- Per-row, 32-column block scaling (MXFP4 microscaling format)
- Only the last (input-feature) dimension doubles: the output is 2× the stored packed width
- Verified: nibble index 0 vs 8 ratio = 0.996 (FP4 -0.0 vs +0.0),
NOT INT4 where index 8 = -8 would be rare
3. FP8 shared expert weights (float8_e4m3fn + float8_e8m0fnu block scales)
- Same as FP8 attention dequantization
After dequantization, all weights are pure BF16. FP8Linear.forward() sees
element_size() > 1 and falls back to F.linear(), avoiding broken FP8 kernels
on Blackwell GPUs. The model can then be loaded by modelopt without shape
mismatches.
"""
import os, glob, json, shutil, sys, time
from safetensors import safe_open
from safetensors.torch import save_file
import torch
FP8_WEIGHT_DTYPE = torch.float8_e4m3fn
FP8_SCALE_DTYPE = torch.float8_e8m0fnu
BLOCK_SIZE_FP8 = (128, 128)
FP4_BLOCK_SIZE = 32 # columns per scale value for MXFP4 expert weights
# E2M1 FP4 lookup table (MXFP4 microscaling format)
# Index 0-7: positive values (sign=0, 2-bit exp, 1-bit mantissa)
# Index 8-15: negative values (sign=1)
# Mapping: 0→0, 1→0.5, 2→1, 3→1.5, 4→2, 5→3, 6→4, 7→6
FP4_E2M1_LUT = torch.tensor([
0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
-0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
], dtype=torch.float32)
def dequantize_fp8_weight(fp8_weight: torch.Tensor, scale: torch.Tensor,
                          block_size=None) -> torch.Tensor:
    """Dequantize a block-wise FP8 weight to BF16.

    Every ``block_size`` tile of ``fp8_weight`` shares one scale value. The
    scale grid is expanded to the weight's shape and multiplied in float32
    before the final cast to bfloat16.

    Args:
        fp8_weight: ``(out_features, in_features)`` weight, typically
            float8_e4m3fn (any dtype that ``.float()`` accepts works).
        scale: ``(ceil(out_features / rows), ceil(in_features / cols))``
            per-block scales, typically float8_e8m0fnu. Extra trailing
            rows/columns are ignored.
        block_size: ``(rows, cols)`` covered by one scale value. ``None``
            means ``BLOCK_SIZE_FP8`` (128x128).

    Returns:
        ``(out_features, in_features)`` bfloat16 tensor.

    Raises:
        ValueError: if either tensor is not 2-D, or ``scale`` has too few
            blocks to cover ``fp8_weight``.
    """
    if block_size is None:
        block_size = BLOCK_SIZE_FP8
    block_rows, block_cols = block_size
    if fp8_weight.dim() != 2 or scale.dim() != 2:
        raise ValueError(
            f"expected 2-D weight and scale, got {tuple(fp8_weight.shape)} "
            f"and {tuple(scale.shape)}"
        )
    out_features, in_features = fp8_weight.shape
    # Ceil-divide: a partial trailing tile still owns a full scale entry.
    row_blocks = -(-out_features // block_rows)
    col_blocks = -(-in_features // block_cols)
    if scale.shape[0] < row_blocks or scale.shape[1] < col_blocks:
        raise ValueError(
            f"scale {tuple(scale.shape)} cannot cover weight "
            f"{tuple(fp8_weight.shape)} with {block_rows}x{block_cols} blocks "
            f"(need at least {row_blocks}x{col_blocks})"
        )
    # Expand each scale entry over its tile, then trim any padding blocks.
    scale_expanded = (
        scale.float()
        .repeat_interleave(block_rows, dim=0)
        .repeat_interleave(block_cols, dim=1)[:out_features, :in_features]
    )
    return (fp8_weight.float() * scale_expanded).to(torch.bfloat16)
def dequantize_fp4_weight(int8_packed: torch.Tensor, scale: torch.Tensor,
                          block_size=None) -> torch.Tensor:
    """Dequantize an MXFP4 (E2M1) packed expert weight to BF16.

    Two FP4 codes are packed per byte. The lower nibble (bits 0-3) is the
    even output column, and the upper nibble (bits 4-7) is the following odd
    column. Each 4-bit code (1 sign, 2 exponent, 1 mantissa bit) is decoded
    through ``FP4_E2M1_LUT`` and multiplied by the scale of its column block.

    Args:
        int8_packed: ``(out_features, in_features // 2)`` int8 tensor of
            packed codes (uint8 works too).
        scale: ``(out_features, ceil(in_features / block_size))`` per-row
            block scales, typically float8_e8m0fnu. Extra trailing columns
            are ignored.
        block_size: Columns covered by one scale value. ``None`` means
            ``FP4_BLOCK_SIZE`` (32).

    Returns:
        ``(out_features, in_features)`` bfloat16 tensor. Only the last
        dimension is doubled relative to ``int8_packed``.

    Raises:
        ValueError: if either input is not 2-D, or ``scale`` does not have one
            row per output feature and enough column blocks.
    """
    if block_size is None:
        block_size = FP4_BLOCK_SIZE
    if int8_packed.dim() != 2 or scale.dim() != 2:
        raise ValueError(
            f"expected 2-D packed weight and scale, got "
            f"{tuple(int8_packed.shape)} and {tuple(scale.shape)}"
        )
    out_features, packed_cols = int8_packed.shape
    in_features_full = packed_cols * 2  # 2 FP4 values per byte
    needed_blocks = -(-in_features_full // block_size)  # ceil-divide
    if scale.shape[0] != out_features or scale.shape[1] < needed_blocks:
        raise ValueError(
            f"scale {tuple(scale.shape)} does not match packed weight "
            f"{tuple(int8_packed.shape)} (need ({out_features}, >={needed_blocks}))"
        )

    lut = FP4_E2M1_LUT.to(int8_packed.device)
    # For signed int8, `>> 4` sign-extends, so both nibbles are masked to 0..15.
    low = int8_packed & 0x0F
    high = (int8_packed >> 4) & 0x0F
    # Stack on a trailing axis and flatten: columns become low0, high0, low1, ...
    # The codes stay int8 until a single widening for the LUT gather.
    codes = torch.stack((low, high), dim=-1).reshape(out_features, in_features_full)
    values = lut[codes.long()]  # float32

    # Expand scale: (out_features, blocks) -> (out_features, in_features_full)
    scale_expanded = scale.float().repeat_interleave(block_size, dim=1)[:, :in_features_full]
    # Dequantize: FP4 value x E8M0 scale
    return (values * scale_expanded).to(torch.bfloat16)
# safetensors header dtype tags for the tensor kinds this script handles.
_FP4_PACKED_TAG = "I8"  # FP4 expert weights are stored as packed int8
_FP8_WEIGHT_TAG = "F8_E4M3"  # float8_e4m3fn
_COMPRESSED_TAGS = frozenset({"I8", "F8_E4M3", "F8_E8M0"})


def _read_safetensors_header(path: str) -> dict:
    """Return a safetensors file's tensor table without loading tensor data.

    A safetensors file starts with a little-endian u64 giving the length of a
    JSON header that maps each tensor name to ``{"dtype", "shape",
    "data_offsets"}``. The optional ``__metadata__`` entry is removed, so every
    returned key is a tensor name.
    """
    import struct

    with open(path, "rb") as fh:
        (header_len,) = struct.unpack("<Q", fh.read(8))
        header = json.loads(fh.read(header_len))
    header.pop("__metadata__", None)
    return header


def _copy_metadata_files(model_dir: str, out_dir: str) -> None:
    """Copy every regular non-.safetensors file (configs, tokenizer, index, ...)."""
    print("Copying metadata files...")
    for name in os.listdir(model_dir):
        src = os.path.join(model_dir, name)
        if not name.endswith(".safetensors") and os.path.isfile(src):
            shutil.copy2(src, os.path.join(out_dir, name))
            print(f" Copied {name}")


def _load_scale(scale_key: str, current_path: str, current_sf, key_to_file: dict):
    """Load a scale tensor, opening its own shard if it is not the current one."""
    scale_path = key_to_file[scale_key]
    if scale_path == current_path:
        return current_sf.get_tensor(scale_key)
    with safe_open(scale_path, framework="pt") as other:
        return other.get_tensor(scale_key)


def _update_index(out_dir: str, dropped_keys: set, total_bytes: int) -> None:
    """Drop removed tensors from any copied shard index and refresh total_size."""
    for index_path in glob.glob(os.path.join(out_dir, "*.safetensors.index.json")):
        with open(index_path) as fh:
            index = json.load(fh)
        weight_map = index.get("weight_map", {})
        index["weight_map"] = {k: v for k, v in weight_map.items() if k not in dropped_keys}
        index.setdefault("metadata", {})["total_size"] = total_bytes
        with open(index_path, "w") as fh:
            json.dump(index, fh, indent=2)
        print(f"Updated {os.path.basename(index_path)}: "
              f"{len(weight_map) - len(index['weight_map'])} entries removed, "
              f"total_size={total_bytes}")


def _update_config(out_dir: str) -> None:
    """Mark config.json as plain BF16 and drop its quantization_config."""
    cfg_path = os.path.join(out_dir, "config.json")
    if not os.path.exists(cfg_path):
        return
    with open(cfg_path) as fh:
        cfg = json.load(fh)
    cfg["torch_dtype"] = "bfloat16"
    cfg["_experts_implementation"] = "eager"
    cfg.pop("quantization_config", None)
    with open(cfg_path, "w") as fh:
        json.dump(cfg, fh, indent=2)
    print("\nUpdated config.json: torch_dtype=bfloat16, _experts_implementation=eager")


def _verify_output(out_dir: str) -> None:
    """Report any FP8/E8M0/int8 tensors left in the output shards (headers only)."""
    print("\nVerifying...")
    remaining_compressed = 0
    for path in sorted(glob.glob(os.path.join(out_dir, "*.safetensors"))):
        for key, info in _read_safetensors_header(path).items():
            if info["dtype"] in _COMPRESSED_TAGS:
                remaining_compressed += 1
                if remaining_compressed <= 5:
                    print(f" REMAINING: {key} {info['dtype']} {tuple(info['shape'])}")
    if remaining_compressed == 0:
        print(" ✅ No compressed tensors remaining — model is pure BF16!")
    else:
        print(f" ⚠️ {remaining_compressed} compressed tensors still present")


def dequantize_model(model_dir: str, out_dir: str):
    """Rewrite every quantized weight under ``model_dir`` as BF16 into ``out_dir``.

    Steps:
      1. Copy all non-.safetensors files.
      2. Scan shard headers (no tensor data) to pair each int8 (packed FP4) or
         float8_e4m3fn ``*.weight`` with its ``*.scale``. The pair may be split
         across shards.
      3. Rewrite each shard with dequantized weights. Only the scales that were
         folded into a weight are dropped; any other ``*.scale`` tensor is kept.
      4. Update the copied shard index and config.json.
      5. Verify that no compressed dtypes remain.

    Raises:
        ValueError: if ``out_dir`` resolves to ``model_dir``.
        FileNotFoundError: if ``model_dir`` contains no .safetensors shards.
    """
    if os.path.realpath(model_dir) == os.path.realpath(out_dir):
        raise ValueError("out_dir must differ from model_dir; shards are rewritten, not patched")
    safetensor_files = sorted(glob.glob(os.path.join(model_dir, "*.safetensors")))
    if not safetensor_files:
        raise FileNotFoundError(f"No .safetensors shards found in {model_dir}")
    total_shards = len(safetensor_files)

    os.makedirs(out_dir, exist_ok=True)
    _copy_metadata_files(model_dir, out_dir)
    print(f"Found {total_shards} shards")

    # First pass: read only the headers, so the location and dtype of every
    # tensor are known before any shard is rewritten.
    print("\nScanning for weight+scale pairs...")
    key_to_file = {}
    key_to_tag = {}
    for path in safetensor_files:
        for key, info in _read_safetensors_header(path).items():
            key_to_file[key] = path
            key_to_tag[key] = info["dtype"]
    weight_to_scale = {}
    for key in key_to_file:
        if key.endswith(".scale"):
            weight_key = key[:-len(".scale")] + ".weight"
            if key_to_tag.get(weight_key) in (_FP4_PACKED_TAG, _FP8_WEIGHT_TAG):
                weight_to_scale[weight_key] = key
    consumed_scales = set(weight_to_scale.values())
    n_fp4 = sum(key_to_tag[w] == _FP4_PACKED_TAG for w in weight_to_scale)
    print(f"Found {len(weight_to_scale)} weight+scale pairs")
    print(f" FP4 (E2M1) expert weights (packed): {n_fp4}")
    print(f" FP8 attention/shared-expert weights: {len(weight_to_scale) - n_fp4}")
    orphan_scales = sum(1 for k in key_to_file if k.endswith(".scale") and k not in consumed_scales)
    if orphan_scales:
        print(f" NOTE: keeping {orphan_scales} .scale tensors that have no quantized weight")

    # Second pass: dequantize and save
    stats = {"fp4_dequantized": 0, "fp8_dequantized": 0, "scales_removed": 0, "unchanged": 0}
    total_bytes = 0
    start_time = time.time()
    for i, path in enumerate(safetensor_files):
        shard_start = time.time()
        tensors = {}
        with safe_open(path, framework="pt") as sf:
            for key in sf.keys():
                if key in consumed_scales:
                    # Folded into its weight (possibly in another shard).
                    stats["scales_removed"] += 1
                    continue
                t = sf.get_tensor(key)
                scale_key = weight_to_scale.get(key)
                if scale_key is None:
                    # Regular tensor (BF16, FP32, int64, orphan scale, ...) - keep as-is
                    tensors[key] = t
                    stats["unchanged"] += 1
                    continue
                scale = _load_scale(scale_key, path, sf, key_to_file)
                if t.dtype == torch.int8:
                    # FP4 (E2M1) packed expert weight (MXFP4 microscaling)
                    tensors[key] = dequantize_fp4_weight(t, scale)
                    stats["fp4_dequantized"] += 1
                elif t.dtype == FP8_WEIGHT_DTYPE:
                    # FP8 weight (attention or shared expert)
                    tensors[key] = dequantize_fp8_weight(t, scale)
                    stats["fp8_dequantized"] += 1
                else:
                    # Its scale is being dropped, so keeping this weight
                    # as-is would silently corrupt it.
                    raise RuntimeError(
                        f"{key}: header dtype {key_to_tag[key]} but loaded {t.dtype}"
                    )
        total_bytes += sum(v.numel() * v.element_size() for v in tensors.values())
        out_path = os.path.join(out_dir, os.path.basename(path))
        # "format": "pt" is the metadata that HF loaders expect on PyTorch shards.
        save_file(tensors, out_path, metadata={"format": "pt"})
        shard_time = time.time() - shard_start
        elapsed = time.time() - start_time
        rate = (i + 1) / elapsed if elapsed > 0 else 0
        eta = (total_shards - i - 1) / rate if rate > 0 else 0
        print(f"[{i+1}/{total_shards}] {os.path.basename(path)} "
              f"({stats['fp4_dequantized']} fp4, {stats['fp8_dequantized']} fp8, "
              f"{stats['scales_removed']} scales rm) "
              f"[{shard_time:.1f}s, ETA: {eta/60:.0f}min]")
        del tensors

    _update_index(out_dir, consumed_scales, total_bytes)
    _update_config(out_dir)

    total_time = time.time() - start_time
    print(f"\nDone in {total_time/60:.1f} minutes!")
    print(f" FP4 expert weights dequantized: {stats['fp4_dequantized']}")
    print(f" FP8 weights dequantized: {stats['fp8_dequantized']}")
    print(f" Scale tensors removed: {stats['scales_removed']}")
    print(f" Unchanged tensors: {stats['unchanged']}")

    _verify_output(out_dir)
    out_size = sum(
        os.path.getsize(os.path.join(out_dir, f))
        for f in os.listdir(out_dir)
        if f.endswith(".safetensors")
    )
    print(f"Output size: {out_size / 1e12:.2f} TB")
if __name__ == "__main__":
    import argparse

    # Command-line entry point: two positional paths forwarded verbatim.
    cli = argparse.ArgumentParser(
        description="Complete dequantization of DeepSeek V4 Pro to BF16",
    )
    cli.add_argument("model_dir", help="Path to mixed-precision model")
    cli.add_argument("out_dir", help="Path to write dequantized BF16 model")
    # The namespace holds exactly model_dir and out_dir, matching the keyword
    # parameters of dequantize_model.
    dequantize_model(**vars(cli.parse_args()))