Complete dequant script: handles INT4 experts, FP8 attention, FP8 shared experts
INT4 expert weights are packed 2-per-byte into int8 with float8_e8m0fnu per-row 32-column block scales. Unpacking: lower nibble first, upper second. Output dimensions are 2x the stored dimensions (e.g. [3072,3584] → [3072,7168]). Also adds progress output with ETA per shard so screen sessions stay alive.
This commit is contained in:
@@ -1,115 +1,218 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Dequantize FP8 weights to BF16 in-place for DeepSeek V4 Pro mixed-precision model.
|
||||
Complete dequantization of DeepSeek V4 Pro mixed-precision to pure BF16.
|
||||
|
||||
For each FP8 weight tensor (float8_e4m3fn) paired with a block-wise scale
|
||||
(float8_e8m0fnu), reconstructs the BF16 weight as:
|
||||
bf16_weight = fp8_weight * scale_expanded
|
||||
Handles ALL compressed tensor types found in the mixed-precision model:
|
||||
|
||||
After dequantization, FP8Linear.forward() sees element_size() > 1 and
|
||||
falls back to F.linear(), avoiding the broken FP8 kernel paths on Blackwell.
|
||||
1. FP8 attention weights (float8_e4m3fn + float8_e8m0fnu block scales)
|
||||
- weight × scale_expanded → BF16
|
||||
- 128×128 block quantization
|
||||
|
||||
Preserves all BF16 and FP32 tensors unchanged.
|
||||
Removes the now-unnecessary scale tensors from the output.
|
||||
2. INT4 expert weights (int8 packed + float8_e8m0fnu block scales)
|
||||
- Unpack 2 int4 values per int8 byte (lower nibble first, upper second)
|
||||
- Dequantize: int4_signed × scale_expanded → BF16
|
||||
- Per-row, 32-column block scaling
|
||||
- Output dimensions are 2× the stored dimensions
|
||||
|
||||
3. FP8 shared expert weights (float8_e4m3fn + float8_e8m0fnu block scales)
|
||||
- Same as FP8 attention dequantization
|
||||
|
||||
After dequantization, all weights are pure BF16. FP8Linear.forward() sees
|
||||
element_size() > 1 and falls back to F.linear(), avoiding broken FP8 kernels
|
||||
on Blackwell GPUs. The model can then be loaded by modelopt without shape
|
||||
mismatches.
|
||||
"""
|
||||
|
||||
import os, glob, json, shutil
|
||||
import os, glob, json, shutil, sys, time
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
import torch
|
||||
|
||||
FP8_WEIGHT_DTYPE = torch.float8_e4m3fn
|
||||
FP8_SCALE_DTYPE = torch.float8_e8m0fnu
|
||||
BLOCK_SIZE = (128, 128)
|
||||
BLOCK_SIZE_FP8 = (128, 128)
|
||||
INT4_BLOCK_SIZE = 32 # columns per scale value for INT4 expert weights
|
||||
|
||||
|
||||
def dequantize_weight(fp8_weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
||||
def dequantize_fp8_weight(fp8_weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
||||
"""Dequantize block-wise FP8 weight to BF16.
|
||||
|
||||
fp8_weight: (out_features, in_features) float8_e4m3fn
|
||||
scale: (out_features//128, in_features//128) float8_e8m0fnu or float32
|
||||
scale: (out_features//128, in_features//128) float8_e8m0fnu
|
||||
"""
|
||||
scale_f32 = scale.float()
|
||||
|
||||
out_features, in_features = fp8_weight.shape
|
||||
scale_expanded = scale_f32.repeat_interleave(BLOCK_SIZE[0], dim=0).repeat_interleave(BLOCK_SIZE[1], dim=1)
|
||||
scale_expanded = scale_f32.repeat_interleave(BLOCK_SIZE_FP8[0], dim=0).repeat_interleave(BLOCK_SIZE_FP8[1], dim=1)
|
||||
scale_expanded = scale_expanded[:out_features, :in_features]
|
||||
|
||||
weight_bf16 = fp8_weight.float() * scale_expanded
|
||||
return weight_bf16.to(torch.bfloat16)
|
||||
|
||||
|
||||
def dequantize_int4_weight(int8_packed: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
||||
"""Dequantize INT4-packed expert weight to BF16.
|
||||
|
||||
INT4 values are packed 2-per-byte into int8 tensors.
|
||||
Lower nibble (bits 0-3) is the first value, upper nibble (bits 4-7) is the second.
|
||||
Signed int4 range: -8 to 7.
|
||||
|
||||
Scale is per-row with 32-column blocks (float8_e8m0fnu).
|
||||
Output dimensions are 2× the stored dimensions.
|
||||
|
||||
int8_packed: (out_features, in_features//2) int8
|
||||
scale: (out_features, in_features//32) float8_e8m0fnu
|
||||
returns: (out_features, in_features) bfloat16
|
||||
"""
|
||||
# Unpack int4 from int8
|
||||
lower = (int8_packed & 0x0F).to(torch.int8) # 0-15
|
||||
upper = ((int8_packed >> 4) & 0x0F).to(torch.int8) # 0-15
|
||||
|
||||
# Convert unsigned to signed int4: 0-7 stay, 8-15 → -8 to -1
|
||||
lower_signed = torch.where(lower > 7, lower - 16, lower).float()
|
||||
upper_signed = torch.where(upper > 7, upper - 16, upper).float()
|
||||
|
||||
out_features = int8_packed.shape[0]
|
||||
in_features_full = int8_packed.shape[1] * 2 # 2× expansion
|
||||
|
||||
# Expand scale: (out_features, in_features//32) → (out_features, in_features)
|
||||
scale_f32 = scale.float()
|
||||
scale_expanded = scale_f32.repeat_interleave(INT4_BLOCK_SIZE, dim=1)
|
||||
scale_expanded = scale_expanded[:, :in_features_full]
|
||||
|
||||
# Interleave lower and upper nibbles
|
||||
unpacked = torch.zeros(out_features, in_features_full, dtype=torch.float32)
|
||||
unpacked[:, 0::2] = lower_signed
|
||||
unpacked[:, 1::2] = upper_signed
|
||||
|
||||
# Dequantize
|
||||
bf16_weight = (unpacked * scale_expanded).to(torch.bfloat16)
|
||||
return bf16_weight
|
||||
|
||||
|
||||
def dequantize_model(model_dir: str, out_dir: str):
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
# Copy non-safetensor files
|
||||
print("Copying metadata files...")
|
||||
for f in os.listdir(model_dir):
|
||||
fp = os.path.join(model_dir, f)
|
||||
if not f.endswith(".safetensors") and os.path.isfile(fp):
|
||||
shutil.copy2(fp, os.path.join(out_dir, f))
|
||||
print(f"Copied {f}")
|
||||
print(f" Copied {f}")
|
||||
|
||||
safetensor_files = sorted(glob.glob(os.path.join(model_dir, "*.safetensors")))
|
||||
total = len(safetensor_files)
|
||||
|
||||
# First pass: build map of scale_key -> weight_key
|
||||
# Pattern: layers.X.attn.Y.scale -> layers.X.attn.Y.weight
|
||||
scales_map = {}
|
||||
|
||||
print("Scanning for FP8 weight+scale pairs...")
|
||||
total_shards = len(safetensor_files)
|
||||
print(f"Found {total_shards} shards")
|
||||
|
||||
# First pass: build scale-key → weight-key mapping
|
||||
# Pattern: *.scale → *.weight
|
||||
print("\nScanning for weight+scale pairs...")
|
||||
scale_to_weight = {} # scale_key → weight_key
|
||||
for f in safetensor_files:
|
||||
with safe_open(f, framework="pt") as sf:
|
||||
for key in sf.keys():
|
||||
if key.endswith(".scale"):
|
||||
weight_key = key[:-len(".scale")] + ".weight"
|
||||
scales_map[weight_key] = key
|
||||
print(f"Found {len(scales_map)} FP8 weight+scale pairs")
|
||||
scale_to_weight[key] = weight_key
|
||||
|
||||
# Second pass: dequantize and save
|
||||
fp8_dequantized = 0
|
||||
fp8_scales_removed = 0
|
||||
scale_keys_global = set(scales_map.values())
|
||||
# Also find weight → scale mapping
|
||||
weight_to_scale = {v: k for k, v in scale_to_weight.items()}
|
||||
print(f"Found {len(scale_to_weight)} weight+scale pairs")
|
||||
|
||||
# Classify weights by type
|
||||
int4_weight_keys = set()
|
||||
fp8_weight_keys = set()
|
||||
scale_keys = set(scale_to_weight.keys())
|
||||
|
||||
for i, f in enumerate(safetensor_files):
|
||||
tensors = {}
|
||||
scales_in_shard = {}
|
||||
fp8_weights_in_shard = {}
|
||||
|
||||
for f in safetensor_files[:2]: # Sample to classify
|
||||
with safe_open(f, framework="pt") as sf:
|
||||
for key in sf.keys():
|
||||
if key in weight_to_scale:
|
||||
t = sf.get_tensor(key)
|
||||
if t.dtype == torch.int8:
|
||||
int4_weight_keys.add(key)
|
||||
elif t.dtype == FP8_WEIGHT_DTYPE:
|
||||
fp8_weight_keys.add(key)
|
||||
|
||||
print(f" INT4 expert weights (packed): ~{len(int4_weight_keys)} per shard")
|
||||
print(f" FP8 attention/shared-expert weights: ~{len(fp8_weight_keys)} per shard")
|
||||
|
||||
# Second pass: dequantize and save
|
||||
stats = {"int4_dequantized": 0, "fp8_dequantized": 0, "scales_removed": 0, "unchanged": 0}
|
||||
start_time = time.time()
|
||||
|
||||
for i, f in enumerate(safetensor_files):
|
||||
shard_start = time.time()
|
||||
tensors = {}
|
||||
scales_in_shard = {}
|
||||
weights_to_dequant = {}
|
||||
|
||||
with safe_open(f, framework="pt") as sf:
|
||||
keys = list(sf.keys())
|
||||
|
||||
# First: collect scales
|
||||
for key in keys:
|
||||
if key in scale_keys:
|
||||
t = sf.get_tensor(key)
|
||||
scales_in_shard[key] = t
|
||||
|
||||
# Second: process weights and other tensors
|
||||
for key in keys:
|
||||
if key in scale_keys:
|
||||
continue # handled separately
|
||||
|
||||
t = sf.get_tensor(key)
|
||||
|
||||
if key in scale_keys_global:
|
||||
# This is a scale tensor - save for dequantization
|
||||
scales_in_shard[key] = t
|
||||
elif key in scales_map and t.dtype == FP8_WEIGHT_DTYPE:
|
||||
# FP8 weight that has a corresponding scale
|
||||
fp8_weights_in_shard[key] = t
|
||||
if key in weight_to_scale and t.dtype == torch.int8:
|
||||
# INT4 packed expert weight
|
||||
scale_key = weight_to_scale[key]
|
||||
scale = scales_in_shard.get(scale_key)
|
||||
if scale is None:
|
||||
print(f" WARNING: scale {scale_key} not in same shard as {key}")
|
||||
tensors[key] = t # keep as-is
|
||||
continue
|
||||
bf16 = dequantize_int4_weight(t, scale)
|
||||
tensors[key] = bf16
|
||||
stats["int4_dequantized"] += 1
|
||||
del scales_in_shard[scale_key]
|
||||
stats["scales_removed"] += 1
|
||||
|
||||
elif key in weight_to_scale and t.dtype == FP8_WEIGHT_DTYPE:
|
||||
# FP8 weight (attention or shared expert)
|
||||
scale_key = weight_to_scale[key]
|
||||
scale = scales_in_shard.get(scale_key)
|
||||
if scale is None:
|
||||
print(f" WARNING: scale {scale_key} not in same shard as {key}")
|
||||
tensors[key] = t
|
||||
continue
|
||||
bf16 = dequantize_fp8_weight(t, scale)
|
||||
tensors[key] = bf16
|
||||
stats["fp8_dequantized"] += 1
|
||||
del scales_in_shard[scale_key]
|
||||
stats["scales_removed"] += 1
|
||||
|
||||
else:
|
||||
# Regular tensor (BF16, FP32, or FP8 without scale) - keep as is
|
||||
# Regular tensor (BF16, FP32, int64, etc.) - keep as-is
|
||||
tensors[key] = t
|
||||
stats["unchanged"] += 1
|
||||
|
||||
# Dequantize FP8 weights
|
||||
for weight_key, fp8_w in fp8_weights_in_shard.items():
|
||||
scale_key = scales_map[weight_key]
|
||||
scale = scales_in_shard.get(scale_key)
|
||||
if scale is not None:
|
||||
bf16_weight = dequantize_weight(fp8_w, scale)
|
||||
tensors[weight_key] = bf16_weight
|
||||
fp8_dequantized += 1
|
||||
del scales_in_shard[scale_key]
|
||||
fp8_scales_removed += 1
|
||||
else:
|
||||
# Scale not in this shard (shouldn't happen but handle gracefully)
|
||||
print(f"WARNING: scale {scale_key} not found in same shard as {weight_key}")
|
||||
tensors[weight_key] = fp8_w # keep as-is
|
||||
# Remove unused scales
|
||||
for sk in scales_in_shard:
|
||||
stats["scales_removed"] += 1
|
||||
|
||||
out_path = os.path.join(out_dir, os.path.basename(f))
|
||||
save_file(tensors, out_path)
|
||||
del tensors, scales_in_shard, fp8_weights_in_shard
|
||||
|
||||
if (i + 1) % 10 == 0 or i == total - 1:
|
||||
print(f"[{i+1}/{total}] {os.path.basename(f)} (dequantized {fp8_dequantized} FP8, removed {fp8_scales_removed} scales)")
|
||||
shard_time = time.time() - shard_start
|
||||
elapsed = time.time() - start_time
|
||||
rate = (i + 1) / elapsed if elapsed > 0 else 0
|
||||
eta = (total_shards - i - 1) / rate if rate > 0 else 0
|
||||
|
||||
print(f"[{i+1}/{total_shards}] {os.path.basename(f)} "
|
||||
f"({stats['int4_dequantized']} int4, {stats['fp8_dequantized']} fp8, "
|
||||
f"{stats['scales_removed']} scales removed) "
|
||||
f"[{shard_time:.1f}s, ETA: {eta/60:.0f}min]")
|
||||
|
||||
del tensors, scales_in_shard
|
||||
|
||||
# Update config
|
||||
cfg_path = os.path.join(out_dir, "config.json")
|
||||
@@ -120,14 +223,38 @@ def dequantize_model(model_dir: str, out_dir: str):
|
||||
if "quantization_config" in cfg:
|
||||
del cfg["quantization_config"]
|
||||
json.dump(cfg, open(cfg_path, "w"), indent=2)
|
||||
print(f"Updated config.json: torch_dtype=bfloat16, _experts_implementation=eager")
|
||||
print(f"\nUpdated config.json: torch_dtype=bfloat16, _experts_implementation=eager")
|
||||
|
||||
print(f"\nDone! Dequantized {fp8_dequantized} FP8 weights, removed {fp8_scales_removed} scale tensors")
|
||||
total_time = time.time() - start_time
|
||||
print(f"\nDone in {total_time/60:.1f} minutes!")
|
||||
print(f" INT4 expert weights dequantized: {stats['int4_dequantized']}")
|
||||
print(f" FP8 weights dequantized: {stats['fp8_dequantized']}")
|
||||
print(f" Scale tensors removed: {stats['scales_removed']}")
|
||||
print(f" Unchanged tensors: {stats['unchanged']}")
|
||||
|
||||
# Verify no FP8/INT8 remaining
|
||||
print("\nVerifying...")
|
||||
remaining_compressed = 0
|
||||
for f in sorted(glob.glob(os.path.join(out_dir, "*.safetensors")))[:5]:
|
||||
with safe_open(f, framework="pt") as sf:
|
||||
for key in sf.keys():
|
||||
t = sf.get_tensor(key)
|
||||
if t.dtype in (torch.float8_e8m0fnu, torch.float8_e4m3fn, torch.int8):
|
||||
remaining_compressed += 1
|
||||
if remaining_compressed <= 5:
|
||||
print(f" REMAINING: {key} {t.dtype} {t.shape}")
|
||||
if remaining_compressed == 0:
|
||||
print(" ✅ No compressed tensors remaining - model is pure BF16!")
|
||||
else:
|
||||
print(f" ⚠️ {remaining_compressed} compressed tensors still present")
|
||||
|
||||
out_size = sum(os.path.getsize(os.path.join(out_dir, f)) for f in os.listdir(out_dir) if f.endswith(".safetensors"))
|
||||
print(f"Output size: {out_size / 1e12:.2f} TB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="Dequantize FP8 weights to BF16")
|
||||
parser = argparse.ArgumentParser(description="Complete dequantization of DeepSeek V4 Pro to BF16")
|
||||
parser.add_argument("model_dir", help="Path to mixed-precision model")
|
||||
parser.add_argument("out_dir", help="Path to write dequantized BF16 model")
|
||||
args = parser.parse_args()
|
||||
|
||||
Reference in New Issue
Block a user