Verified that our FP4 dequantization is byte-identical to the official transformers MXFP4 implementation (max diff = 0.0 across all values).
277 lines
11 KiB
Python
277 lines
11 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Complete dequantization of DeepSeek V4 Pro mixed-precision to pure BF16.
|
||
|
||
Handles ALL compressed tensor types found in the mixed-precision model:
|
||
|
||
1. FP8 attention weights (float8_e4m3fn + float8_e8m0fnu block scales)
|
||
- weight × scale_expanded → BF16
|
||
- 128×128 block quantization
|
||
|
||
2. FP4 (E2M1) expert weights (int8 packed + float8_e8m0fnu block scales)
|
||
- Unpack 2 FP4 values per int8 byte (lower nibble first, upper second)
|
||
- Dequantize via E2M1 LUT lookup × scale_expanded → BF16
|
||
- Per-row, 32-column block scaling (MXFP4 microscaling format)
|
||
- Output dimensions are 2× the stored dimensions
|
||
- Verified: nibble index 0 vs 8 ratio = 0.996 (FP4 -0.0 vs +0.0),
|
||
NOT INT4 where index 8 = -8 would be rare
|
||
|
||
3. FP8 shared expert weights (float8_e4m3fn + float8_e8m0fnu block scales)
|
||
- Same as FP8 attention dequantization
|
||
|
||
After dequantization, all weights are pure BF16. FP8Linear.forward() sees
|
||
element_size() > 1 and falls back to F.linear(), avoiding broken FP8 kernels
|
||
on Blackwell GPUs. The model can then be loaded by modelopt without shape
|
||
mismatches.
|
||
"""
|
||
|
||
import os, glob, json, shutil, sys, time
|
||
from safetensors import safe_open
|
||
from safetensors.torch import save_file
|
||
import torch
|
||
|
||
FP8_WEIGHT_DTYPE = torch.float8_e4m3fn
|
||
FP8_SCALE_DTYPE = torch.float8_e8m0fnu
|
||
BLOCK_SIZE_FP8 = (128, 128)
|
||
FP4_BLOCK_SIZE = 32 # columns per scale value for MXFP4 expert weights
|
||
|
||
# E2M1 FP4 lookup table (MXFP4 microscaling format)
|
||
# Index 0-7: positive values (sign=0, 2-bit exp, 1-bit mantissa)
|
||
# Index 8-15: negative values (sign=1)
|
||
# Mapping: 0→0, 1→0.5, 2→1, 3→1.5, 4→2, 5→3, 6→4, 7→6
|
||
FP4_E2M1_LUT = torch.tensor([
|
||
0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
|
||
-0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
|
||
], dtype=torch.float32)
|
||
|
||
|
||
def dequantize_fp8_weight(fp8_weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
||
"""Dequantize block-wise FP8 weight to BF16.
|
||
|
||
fp8_weight: (out_features, in_features) float8_e4m3fn
|
||
scale: (out_features//128, in_features//128) float8_e8m0fnu
|
||
"""
|
||
scale_f32 = scale.float()
|
||
out_features, in_features = fp8_weight.shape
|
||
scale_expanded = scale_f32.repeat_interleave(BLOCK_SIZE_FP8[0], dim=0).repeat_interleave(BLOCK_SIZE_FP8[1], dim=1)
|
||
scale_expanded = scale_expanded[:out_features, :in_features]
|
||
weight_bf16 = fp8_weight.float() * scale_expanded
|
||
return weight_bf16.to(torch.bfloat16)
|
||
|
||
|
||
def dequantize_fp4_weight(
    int8_packed: torch.Tensor,
    scale: torch.Tensor,
    block_size: "int | None" = None,
) -> torch.Tensor:
    """Dequantize an MXFP4 (E2M1) packed weight to BF16.

    Two 4-bit codes are packed per byte. The lower nibble (bits 0-3) is the first
    value and the upper nibble (bits 4-7) is the second, so the last dimension
    doubles. Each code is decoded through ``FP4_E2M1_LUT``. Every run of
    ``block_size`` consecutive output columns is then multiplied by one scale
    value. The product is computed in float32 and cast to bfloat16 once, the
    same arithmetic order as the original implementation.

    Args:
        int8_packed: ``(..., out_features, in_features // 2)`` packed codes
            (int8 in this checkpoint; uint8 also works since only the low 8
            bits are inspected).
        scale: ``(..., out_features, ceil(in_features / block_size))`` block
            scales, normally float8_e8m0fnu.
        block_size: output columns per scale value. ``None`` means the module
            default ``FP4_BLOCK_SIZE``.

    Returns:
        ``(..., out_features, in_features)`` bfloat16 tensor.

    Raises:
        ValueError: if ``scale`` has too few blocks to cover every column.
    """
    cols_per_scale = FP4_BLOCK_SIZE if block_size is None else block_size
    in_features_full = int8_packed.shape[-1] * 2  # two FP4 values per byte

    # Validate before allocating the large float32 intermediates.
    if scale.shape[-1] * cols_per_scale < in_features_full:
        raise ValueError(
            f"scale of shape {tuple(scale.shape)} with block size {cols_per_scale} "
            f"does not cover {in_features_full} unpacked columns"
        )

    lut = FP4_E2M1_LUT.to(int8_packed.device)

    # Mask after shifting: for signed int8 the right shift is arithmetic and
    # sign-extends, so "& 0x0F" is what limits the upper code to 0-15.
    low = lut[(int8_packed & 0x0F).long()]
    high = lut[((int8_packed >> 4) & 0x0F).long()]

    # (..., C, 2) -> (..., 2C) yields low0, high0, low1, high1, ... in order.
    values = torch.stack((low, high), dim=-1).flatten(start_dim=-2)

    # Broadcast each scale over its column block; trim any overhang.
    scale_expanded = scale.float().repeat_interleave(cols_per_scale, dim=-1)[..., :in_features_full]

    return (values * scale_expanded).to(torch.bfloat16)
|
||
|
||
|
||
def _copy_metadata_files(model_dir: str, out_dir: str) -> None:
    """Copy every regular file that is not a ``.safetensors`` shard from model_dir to out_dir."""
    print("Copying metadata files...")
    for name in os.listdir(model_dir):
        src = os.path.join(model_dir, name)
        if not name.endswith(".safetensors") and os.path.isfile(src):
            shutil.copy2(src, os.path.join(out_dir, name))
            print(f" Copied {name}")


def _index_checkpoint(safetensor_files: list) -> tuple:
    """Scan all shard headers once.

    Returns ``(key_to_file, scale_to_weight)``:
    - the shard path that stores each tensor key;
    - the ``<prefix>.weight`` key paired with every ``<prefix>.scale`` key.
    """
    key_to_file = {}
    scale_to_weight = {}
    for path in safetensor_files:
        with safe_open(path, framework="pt") as sf:
            for key in sf.keys():
                key_to_file[key] = path
                if key.endswith(".scale"):
                    scale_to_weight[key] = key[: -len(".scale")] + ".weight"
    return key_to_file, scale_to_weight


def _report_weight_kinds(sample_files: list, weight_to_scale: dict) -> None:
    """Print the average number of packed-FP4 (int8) and FP8 scaled weights per sampled shard."""
    fp4_count = fp8_count = 0
    for path in sample_files:
        with safe_open(path, framework="pt") as sf:
            for key in sf.keys():
                if key not in weight_to_scale:
                    continue
                # get_slice() is lazy: the dtype comes from the header, no tensor data is read.
                dtype = sf.get_slice(key).get_dtype()
                if dtype == "I8":
                    fp4_count += 1
                elif dtype == "F8_E4M3":
                    fp8_count += 1
    n_sampled = max(len(sample_files), 1)
    print(f" FP4 (E2M1) expert weights (packed): ~{round(fp4_count / n_sampled)} per shard")
    print(f" FP8 attention/shared-expert weights: ~{round(fp8_count / n_sampled)} per shard")


def _dequantize_shard(path: str, scale_keys: set, weight_to_scale: dict,
                      key_to_file: dict, stats: dict) -> dict:
    """Return one shard's tensors with scaled int8/FP8 weights dequantized to BF16.

    Scale tensors are consumed and never included in the result. If a weight's
    scale is stored in another shard, it is loaded from the shard recorded in
    key_to_file. Counters in ``stats`` are updated in place.
    """
    # dtype of a scaled weight -> (stats counter, dequantizer)
    dequantizers = {
        torch.int8: ("fp4_dequantized", dequantize_fp4_weight),
        FP8_WEIGHT_DTYPE: ("fp8_dequantized", dequantize_fp8_weight),
    }
    tensors = {}
    with safe_open(path, framework="pt") as sf:
        keys = list(sf.keys())
        local_scales = {key: sf.get_tensor(key) for key in keys if key in scale_keys}
        # Every scale stored in this shard is left out of the output shard.
        stats["scales_removed"] += len(local_scales)

        for key in keys:
            if key in scale_keys:
                continue
            t = sf.get_tensor(key)
            handler = dequantizers.get(t.dtype) if key in weight_to_scale else None
            if handler is None:
                # Regular tensor (BF16, FP32, int64, ...) - keep as-is.
                tensors[key] = t
                stats["unchanged"] += 1
                continue

            stat_name, dequantize = handler
            scale_key = weight_to_scale[key]
            scale = local_scales.pop(scale_key, None)
            if scale is None and scale_key in key_to_file:
                # The scale lives in a different shard. Load it from there rather than
                # keeping a compressed weight whose scale would be dropped from the output.
                with safe_open(key_to_file[scale_key], framework="pt") as other:
                    scale = other.get_tensor(scale_key)
            if scale is None:
                print(f" WARNING: scale {scale_key} not found for {key}; keeping it as-is")
                tensors[key] = t
                continue

            tensors[key] = dequantize(t, scale)
            stats[stat_name] += 1
    return tensors


def _drop_keys_from_index(out_dir: str, dropped_keys: set) -> None:
    """Remove dropped_keys from ``weight_map`` in every ``*.safetensors.index.json`` in out_dir.

    The index files were copied unchanged from the source checkpoint, so without
    this they would still reference scale tensors that no output shard contains.
    """
    for index_path in glob.glob(os.path.join(out_dir, "*.safetensors.index.json")):
        with open(index_path) as fh:
            index = json.load(fh)
        weight_map = index.get("weight_map") if isinstance(index, dict) else None
        if not isinstance(weight_map, dict):
            continue
        kept = {key: shard for key, shard in weight_map.items() if key not in dropped_keys}
        if len(kept) == len(weight_map):
            continue
        index["weight_map"] = kept
        # NOTE(review): metadata.total_size (if present) still describes the source checkpoint.
        with open(index_path, "w") as fh:
            json.dump(index, fh, indent=2)
        print(f"Updated {os.path.basename(index_path)}: "
              f"dropped {len(weight_map) - len(kept)} scale entries from weight_map")


def _update_config(out_dir: str) -> None:
    """Mark the copied config.json as BF16 / eager experts and drop its quantization_config."""
    cfg_path = os.path.join(out_dir, "config.json")
    if not os.path.exists(cfg_path):
        return
    with open(cfg_path) as fh:
        cfg = json.load(fh)
    cfg["torch_dtype"] = "bfloat16"
    cfg["_experts_implementation"] = "eager"
    cfg.pop("quantization_config", None)
    with open(cfg_path, "w") as fh:
        json.dump(cfg, fh, indent=2)
    print("\nUpdated config.json: torch_dtype=bfloat16, _experts_implementation=eager")


def _verify_output(out_dir: str, max_shards: int = 5) -> None:
    """Report any int8 / FP8 tensors left in the first max_shards output shards (header-only check)."""
    print("\nVerifying...")
    remaining_compressed = 0
    for path in sorted(glob.glob(os.path.join(out_dir, "*.safetensors")))[:max_shards]:
        with safe_open(path, framework="pt") as sf:
            for key in sf.keys():
                tensor_slice = sf.get_slice(key)
                dtype = tensor_slice.get_dtype()
                if dtype == "I8" or dtype.startswith("F8"):
                    remaining_compressed += 1
                    if remaining_compressed <= 5:
                        print(f" REMAINING: {key} {dtype} {tensor_slice.get_shape()}")
    if remaining_compressed == 0:
        print(" ✅ No compressed tensors remaining — model is pure BF16!")
    else:
        print(f" ⚠️ {remaining_compressed} compressed tensors still present")


def dequantize_model(model_dir: str, out_dir: str):
    """Dequantize every ``*.safetensors`` shard in model_dir to BF16 and write it to out_dir.

    Steps:
    1. Copy non-shard files.
    2. Index all tensor keys and ``.scale``/``.weight`` pairs.
    3. Dequantize int8-packed FP4 and float8_e4m3fn weights shard by shard,
       dropping every scale tensor.
    4. Remove the scale entries from any copied ``*.safetensors.index.json``.
    5. Update ``config.json``.
    6. Spot-check the first output shards.

    Resumable: a shard whose output file already exists with non-zero size is
    skipped before anything is loaded. Each shard is first written to
    ``<name>.partial`` and then renamed into place, so an interrupted write never
    leaves a truncated file under the final name.
    """
    os.makedirs(out_dir, exist_ok=True)
    _copy_metadata_files(model_dir, out_dir)

    safetensor_files = sorted(glob.glob(os.path.join(model_dir, "*.safetensors")))
    total_shards = len(safetensor_files)
    print(f"Found {total_shards} shards")

    print("\nScanning for weight+scale pairs...")
    key_to_file, scale_to_weight = _index_checkpoint(safetensor_files)
    weight_to_scale = {weight: scale for scale, weight in scale_to_weight.items()}
    scale_keys = set(scale_to_weight)
    print(f"Found {len(scale_to_weight)} weight+scale pairs")

    _report_weight_kinds(safetensor_files[:2], weight_to_scale)

    stats = {"fp4_dequantized": 0, "fp8_dequantized": 0, "scales_removed": 0, "unchanged": 0}
    start_time = time.time()
    processed = 0

    for i, path in enumerate(safetensor_files):
        shard_name = os.path.basename(path)
        out_path = os.path.join(out_dir, shard_name)
        # Resume: check before loading, so finished shards cost nothing.
        if os.path.exists(out_path) and os.path.getsize(out_path) > 0:
            print(f"[{i+1}/{total_shards}] Skipping (already done): {shard_name}")
            continue

        shard_start = time.time()
        tensors = _dequantize_shard(path, scale_keys, weight_to_scale, key_to_file, stats)
        tmp_path = out_path + ".partial"
        save_file(tensors, tmp_path, metadata={"format": "pt"})
        os.replace(tmp_path, out_path)  # atomic rename: out_path is either absent or complete
        del tensors
        processed += 1

        shard_time = time.time() - shard_start
        elapsed = time.time() - start_time
        # Rate counts only shards actually processed, so skipped shards don't skew the ETA.
        rate = processed / elapsed if elapsed > 0 else 0
        eta = (total_shards - i - 1) / rate if rate > 0 else 0

        print(f"[{i+1}/{total_shards}] {shard_name} "
              f"({stats['fp4_dequantized']} fp4, {stats['fp8_dequantized']} fp8, "
              f"{stats['scales_removed']} scales rm) "
              f"[{shard_time:.1f}s, ETA: {eta/60:.0f}min]")

    _drop_keys_from_index(out_dir, scale_keys)
    _update_config(out_dir)

    total_time = time.time() - start_time
    print(f"\nDone in {total_time/60:.1f} minutes!")
    print(f" FP4 expert weights dequantized: {stats['fp4_dequantized']}")
    print(f" FP8 weights dequantized: {stats['fp8_dequantized']}")
    print(f" Scale tensors removed: {stats['scales_removed']}")
    print(f" Unchanged tensors: {stats['unchanged']}")

    _verify_output(out_dir)

    out_size = sum(
        os.path.getsize(os.path.join(out_dir, name))
        for name in os.listdir(out_dir)
        if name.endswith(".safetensors")
    )
    print(f"Output size: {out_size / 1e12:.2f} TB")
|
||
|
||
|
||
if __name__ == "__main__":
    # Command-line entry point: two positional paths, then run the conversion.
    from argparse import ArgumentParser

    cli = ArgumentParser(description="Complete dequantization of DeepSeek V4 Pro to BF16")
    cli.add_argument("model_dir", help="Path to mixed-precision model")
    cli.add_argument("out_dir", help="Path to write dequantized BF16 model")
    opts = cli.parse_args()

    dequantize_model(opts.model_dir, opts.out_dir)
|