Complete dequant script: handles INT4 experts, FP8 attention, FP8 shared experts

INT4 expert weights are packed 2-per-byte into int8 with float8_e8m0fnu
per-row 32-column block scales. Unpacking: lower nibble first, upper second.
Output dimensions are 2x the stored dimensions (e.g. [3072,3584] → [3072,7168]).

Also adds progress output with ETA per shard so screen sessions stay alive.
This commit is contained in:
2026-05-08 01:39:50 +00:00
parent cbfc5a9afb
commit db6beb5b76

View File

@@ -1,115 +1,218 @@
#!/usr/bin/env python3
"""
Dequantize FP8 weights to BF16 in-place for DeepSeek V4 Pro mixed-precision model.
Complete dequantization of DeepSeek V4 Pro mixed-precision to pure BF16.
For each FP8 weight tensor (float8_e4m3fn) paired with a block-wise scale
(float8_e8m0fnu), reconstructs the BF16 weight as:
bf16_weight = fp8_weight * scale_expanded
Handles ALL compressed tensor types found in the mixed-precision model:
After dequantization, FP8Linear.forward() sees element_size() > 1 and
falls back to F.linear(), avoiding the broken FP8 kernel paths on Blackwell.
1. FP8 attention weights (float8_e4m3fn + float8_e8m0fnu block scales)
- weight × scale_expanded → BF16
- 128×128 block quantization
Preserves all BF16 and FP32 tensors unchanged.
Removes the now-unnecessary scale tensors from the output.
2. INT4 expert weights (int8 packed + float8_e8m0fnu block scales)
- Unpack 2 int4 values per int8 byte (lower nibble first, upper second)
- Dequantize: int4_signed × scale_expanded → BF16
- Per-row, 32-column block scaling
- Output dimensions are 2× the stored dimensions
3. FP8 shared expert weights (float8_e4m3fn + float8_e8m0fnu block scales)
- Same as FP8 attention dequantization
After dequantization, all weights are pure BF16. FP8Linear.forward() sees
element_size() > 1 and falls back to F.linear(), avoiding broken FP8 kernels
on Blackwell GPUs. The model can then be loaded by modelopt without shape
mismatches.
"""
import os, glob, json, shutil
import os, glob, json, shutil, sys, time
from safetensors import safe_open
from safetensors.torch import save_file
import torch
FP8_WEIGHT_DTYPE = torch.float8_e4m3fn
FP8_SCALE_DTYPE = torch.float8_e8m0fnu
BLOCK_SIZE = (128, 128)
BLOCK_SIZE_FP8 = (128, 128)
INT4_BLOCK_SIZE = 32 # columns per scale value for INT4 expert weights
def dequantize_weight(fp8_weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
def dequantize_fp8_weight(fp8_weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
"""Dequantize block-wise FP8 weight to BF16.
fp8_weight: (out_features, in_features) float8_e4m3fn
scale: (out_features//128, in_features//128) float8_e8m0fnu or float32
scale: (out_features//128, in_features//128) float8_e8m0fnu
"""
scale_f32 = scale.float()
out_features, in_features = fp8_weight.shape
scale_expanded = scale_f32.repeat_interleave(BLOCK_SIZE[0], dim=0).repeat_interleave(BLOCK_SIZE[1], dim=1)
scale_expanded = scale_f32.repeat_interleave(BLOCK_SIZE_FP8[0], dim=0).repeat_interleave(BLOCK_SIZE_FP8[1], dim=1)
scale_expanded = scale_expanded[:out_features, :in_features]
weight_bf16 = fp8_weight.float() * scale_expanded
return weight_bf16.to(torch.bfloat16)
def dequantize_int4_weight(int8_packed: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
"""Dequantize INT4-packed expert weight to BF16.
INT4 values are packed 2-per-byte into int8 tensors.
Lower nibble (bits 0-3) is the first value, upper nibble (bits 4-7) is the second.
Signed int4 range: -8 to 7.
Scale is per-row with 32-column blocks (float8_e8m0fnu).
Output dimensions are 2× the stored dimensions.
int8_packed: (out_features, in_features//2) int8
scale: (out_features, in_features//32) float8_e8m0fnu
returns: (out_features, in_features) bfloat16
"""
# Unpack int4 from int8
lower = (int8_packed & 0x0F).to(torch.int8) # 0-15
upper = ((int8_packed >> 4) & 0x0F).to(torch.int8) # 0-15
# Convert unsigned to signed int4: 0-7 stay, 8-15 → -8 to -1
lower_signed = torch.where(lower > 7, lower - 16, lower).float()
upper_signed = torch.where(upper > 7, upper - 16, upper).float()
out_features = int8_packed.shape[0]
in_features_full = int8_packed.shape[1] * 2 # 2× expansion
# Expand scale: (out_features, in_features//32) → (out_features, in_features)
scale_f32 = scale.float()
scale_expanded = scale_f32.repeat_interleave(INT4_BLOCK_SIZE, dim=1)
scale_expanded = scale_expanded[:, :in_features_full]
# Interleave lower and upper nibbles
unpacked = torch.zeros(out_features, in_features_full, dtype=torch.float32)
unpacked[:, 0::2] = lower_signed
unpacked[:, 1::2] = upper_signed
# Dequantize
bf16_weight = (unpacked * scale_expanded).to(torch.bfloat16)
return bf16_weight
def dequantize_model(model_dir: str, out_dir: str):
os.makedirs(out_dir, exist_ok=True)
# Copy non-safetensor files
print("Copying metadata files...")
for f in os.listdir(model_dir):
fp = os.path.join(model_dir, f)
if not f.endswith(".safetensors") and os.path.isfile(fp):
shutil.copy2(fp, os.path.join(out_dir, f))
print(f"Copied {f}")
print(f" Copied {f}")
safetensor_files = sorted(glob.glob(os.path.join(model_dir, "*.safetensors")))
total = len(safetensor_files)
# First pass: build map of scale_key -> weight_key
# Pattern: layers.X.attn.Y.scale -> layers.X.attn.Y.weight
scales_map = {}
print("Scanning for FP8 weight+scale pairs...")
total_shards = len(safetensor_files)
print(f"Found {total_shards} shards")
# First pass: build scale-key → weight-key mapping
# Pattern: *.scale → *.weight
print("\nScanning for weight+scale pairs...")
scale_to_weight = {} # scale_key → weight_key
for f in safetensor_files:
with safe_open(f, framework="pt") as sf:
for key in sf.keys():
if key.endswith(".scale"):
weight_key = key[:-len(".scale")] + ".weight"
scales_map[weight_key] = key
print(f"Found {len(scales_map)} FP8 weight+scale pairs")
scale_to_weight[key] = weight_key
# Second pass: dequantize and save
fp8_dequantized = 0
fp8_scales_removed = 0
scale_keys_global = set(scales_map.values())
# Also find weight → scale mapping
weight_to_scale = {v: k for k, v in scale_to_weight.items()}
print(f"Found {len(scale_to_weight)} weight+scale pairs")
# Classify weights by type
int4_weight_keys = set()
fp8_weight_keys = set()
scale_keys = set(scale_to_weight.keys())
for i, f in enumerate(safetensor_files):
tensors = {}
scales_in_shard = {}
fp8_weights_in_shard = {}
for f in safetensor_files[:2]: # Sample to classify
with safe_open(f, framework="pt") as sf:
for key in sf.keys():
if key in weight_to_scale:
t = sf.get_tensor(key)
if t.dtype == torch.int8:
int4_weight_keys.add(key)
elif t.dtype == FP8_WEIGHT_DTYPE:
fp8_weight_keys.add(key)
print(f" INT4 expert weights (packed): ~{len(int4_weight_keys)} per shard")
print(f" FP8 attention/shared-expert weights: ~{len(fp8_weight_keys)} per shard")
# Second pass: dequantize and save
stats = {"int4_dequantized": 0, "fp8_dequantized": 0, "scales_removed": 0, "unchanged": 0}
start_time = time.time()
for i, f in enumerate(safetensor_files):
shard_start = time.time()
tensors = {}
scales_in_shard = {}
weights_to_dequant = {}
with safe_open(f, framework="pt") as sf:
keys = list(sf.keys())
# First: collect scales
for key in keys:
if key in scale_keys:
t = sf.get_tensor(key)
scales_in_shard[key] = t
# Second: process weights and other tensors
for key in keys:
if key in scale_keys:
continue # handled separately
t = sf.get_tensor(key)
if key in scale_keys_global:
# This is a scale tensor - save for dequantization
scales_in_shard[key] = t
elif key in scales_map and t.dtype == FP8_WEIGHT_DTYPE:
# FP8 weight that has a corresponding scale
fp8_weights_in_shard[key] = t
if key in weight_to_scale and t.dtype == torch.int8:
# INT4 packed expert weight
scale_key = weight_to_scale[key]
scale = scales_in_shard.get(scale_key)
if scale is None:
print(f" WARNING: scale {scale_key} not in same shard as {key}")
tensors[key] = t # keep as-is
continue
bf16 = dequantize_int4_weight(t, scale)
tensors[key] = bf16
stats["int4_dequantized"] += 1
del scales_in_shard[scale_key]
stats["scales_removed"] += 1
elif key in weight_to_scale and t.dtype == FP8_WEIGHT_DTYPE:
# FP8 weight (attention or shared expert)
scale_key = weight_to_scale[key]
scale = scales_in_shard.get(scale_key)
if scale is None:
print(f" WARNING: scale {scale_key} not in same shard as {key}")
tensors[key] = t
continue
bf16 = dequantize_fp8_weight(t, scale)
tensors[key] = bf16
stats["fp8_dequantized"] += 1
del scales_in_shard[scale_key]
stats["scales_removed"] += 1
else:
# Regular tensor (BF16, FP32, or FP8 without scale) - keep as is
# Regular tensor (BF16, FP32, int64, etc.) - keep as-is
tensors[key] = t
stats["unchanged"] += 1
# Dequantize FP8 weights
for weight_key, fp8_w in fp8_weights_in_shard.items():
scale_key = scales_map[weight_key]
scale = scales_in_shard.get(scale_key)
if scale is not None:
bf16_weight = dequantize_weight(fp8_w, scale)
tensors[weight_key] = bf16_weight
fp8_dequantized += 1
del scales_in_shard[scale_key]
fp8_scales_removed += 1
else:
# Scale not in this shard (shouldn't happen but handle gracefully)
print(f"WARNING: scale {scale_key} not found in same shard as {weight_key}")
tensors[weight_key] = fp8_w # keep as-is
# Remove unused scales
for sk in scales_in_shard:
stats["scales_removed"] += 1
out_path = os.path.join(out_dir, os.path.basename(f))
save_file(tensors, out_path)
del tensors, scales_in_shard, fp8_weights_in_shard
if (i + 1) % 10 == 0 or i == total - 1:
print(f"[{i+1}/{total}] {os.path.basename(f)} (dequantized {fp8_dequantized} FP8, removed {fp8_scales_removed} scales)")
shard_time = time.time() - shard_start
elapsed = time.time() - start_time
rate = (i + 1) / elapsed if elapsed > 0 else 0
eta = (total_shards - i - 1) / rate if rate > 0 else 0
print(f"[{i+1}/{total_shards}] {os.path.basename(f)} "
f"({stats['int4_dequantized']} int4, {stats['fp8_dequantized']} fp8, "
f"{stats['scales_removed']} scales removed) "
f"[{shard_time:.1f}s, ETA: {eta/60:.0f}min]")
del tensors, scales_in_shard
# Update config
cfg_path = os.path.join(out_dir, "config.json")
@@ -120,14 +223,38 @@ def dequantize_model(model_dir: str, out_dir: str):
if "quantization_config" in cfg:
del cfg["quantization_config"]
json.dump(cfg, open(cfg_path, "w"), indent=2)
print(f"Updated config.json: torch_dtype=bfloat16, _experts_implementation=eager")
print(f"\nUpdated config.json: torch_dtype=bfloat16, _experts_implementation=eager")
print(f"\nDone! Dequantized {fp8_dequantized} FP8 weights, removed {fp8_scales_removed} scale tensors")
total_time = time.time() - start_time
print(f"\nDone in {total_time/60:.1f} minutes!")
print(f" INT4 expert weights dequantized: {stats['int4_dequantized']}")
print(f" FP8 weights dequantized: {stats['fp8_dequantized']}")
print(f" Scale tensors removed: {stats['scales_removed']}")
print(f" Unchanged tensors: {stats['unchanged']}")
# Verify no FP8/INT8 remaining
print("\nVerifying...")
remaining_compressed = 0
for f in sorted(glob.glob(os.path.join(out_dir, "*.safetensors")))[:5]:
with safe_open(f, framework="pt") as sf:
for key in sf.keys():
t = sf.get_tensor(key)
if t.dtype in (torch.float8_e8m0fnu, torch.float8_e4m3fn, torch.int8):
remaining_compressed += 1
if remaining_compressed <= 5:
print(f" REMAINING: {key} {t.dtype} {t.shape}")
if remaining_compressed == 0:
print(" ✅ No compressed tensors remaining - model is pure BF16!")
else:
print(f" ⚠️ {remaining_compressed} compressed tensors still present")
out_size = sum(os.path.getsize(os.path.join(out_dir, f)) for f in os.listdir(out_dir) if f.endswith(".safetensors"))
print(f"Output size: {out_size / 1e12:.2f} TB")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Dequantize FP8 weights to BF16")
parser = argparse.ArgumentParser(description="Complete dequantization of DeepSeek V4 Pro to BF16")
parser.add_argument("model_dir", help="Path to mixed-precision model")
parser.add_argument("out_dir", help="Path to write dequantized BF16 model")
args = parser.parse_args()