diff --git a/scripts/dequant_fp8_to_bf16.py b/scripts/dequant_fp8_to_bf16.py index 649db8b..5a86a06 100644 --- a/scripts/dequant_fp8_to_bf16.py +++ b/scripts/dequant_fp8_to_bf16.py @@ -1,115 +1,218 @@ #!/usr/bin/env python3 """ -Dequantize FP8 weights to BF16 in-place for DeepSeek V4 Pro mixed-precision model. +Complete dequantization of DeepSeek V4 Pro mixed-precision to pure BF16. -For each FP8 weight tensor (float8_e4m3fn) paired with a block-wise scale -(float8_e8m0fnu), reconstructs the BF16 weight as: - bf16_weight = fp8_weight * scale_expanded +Handles ALL compressed tensor types found in the mixed-precision model: -After dequantization, FP8Linear.forward() sees element_size() > 1 and -falls back to F.linear(), avoiding the broken FP8 kernel paths on Blackwell. +1. FP8 attention weights (float8_e4m3fn + float8_e8m0fnu block scales) + - weight × scale_expanded → BF16 + - 128×128 block quantization -Preserves all BF16 and FP32 tensors unchanged. -Removes the now-unnecessary scale tensors from the output. +2. INT4 expert weights (int8 packed + float8_e8m0fnu block scales) + - Unpack 2 int4 values per int8 byte (lower nibble first, upper second) + - Dequantize: int4_signed × scale_expanded → BF16 + - Per-row, 32-column block scaling + - Output dimensions are 2× the stored dimensions + +3. FP8 shared expert weights (float8_e4m3fn + float8_e8m0fnu block scales) + - Same as FP8 attention dequantization + +After dequantization, all weights are pure BF16. FP8Linear.forward() sees +element_size() > 1 and falls back to F.linear(), avoiding broken FP8 kernels +on Blackwell GPUs. The model can then be loaded by modelopt without shape +mismatches. """ -import os, glob, json, shutil +import os, glob, json, shutil, sys, time from safetensors import safe_open from safetensors.torch import save_file import torch FP8_WEIGHT_DTYPE = torch.float8_e4m3fn FP8_SCALE_DTYPE = torch.float8_e8m0fnu -BLOCK_SIZE = (128, 128) +BLOCK_SIZE_FP8 = (128, 128) +INT4_BLOCK_SIZE = 32 # columns per scale value for INT4 expert weights -def dequantize_weight(fp8_weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: +def dequantize_fp8_weight(fp8_weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: """Dequantize block-wise FP8 weight to BF16. fp8_weight: (out_features, in_features) float8_e4m3fn - scale: (out_features//128, in_features//128) float8_e8m0fnu or float32 + scale: (out_features//128, in_features//128) float8_e8m0fnu """ scale_f32 = scale.float() - out_features, in_features = fp8_weight.shape - scale_expanded = scale_f32.repeat_interleave(BLOCK_SIZE[0], dim=0).repeat_interleave(BLOCK_SIZE[1], dim=1) + scale_expanded = scale_f32.repeat_interleave(BLOCK_SIZE_FP8[0], dim=0).repeat_interleave(BLOCK_SIZE_FP8[1], dim=1) scale_expanded = scale_expanded[:out_features, :in_features] - weight_bf16 = fp8_weight.float() * scale_expanded return weight_bf16.to(torch.bfloat16) +def dequantize_int4_weight(int8_packed: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """Dequantize INT4-packed expert weight to BF16. + + INT4 values are packed 2-per-byte into int8 tensors. + Lower nibble (bits 0-3) is the first value, upper nibble (bits 4-7) is the second. + Signed int4 range: -8 to 7. + + Scale is per-row with 32-column blocks (float8_e8m0fnu). + Output dimensions are 2× the stored dimensions. + + int8_packed: (out_features, in_features//2) int8 + scale: (out_features, in_features//32) float8_e8m0fnu + returns: (out_features, in_features) bfloat16 + """ + # Unpack int4 from int8 + lower = (int8_packed & 0x0F).to(torch.int8) # 0-15 + upper = ((int8_packed >> 4) & 0x0F).to(torch.int8) # 0-15 + + # Convert unsigned to signed int4: 0-7 stay, 8-15 → -8 to -1 + lower_signed = torch.where(lower > 7, lower - 16, lower).float() + upper_signed = torch.where(upper > 7, upper - 16, upper).float() + + out_features = int8_packed.shape[0] + in_features_full = int8_packed.shape[1] * 2 # 2× expansion + + # Expand scale: (out_features, in_features//32) → (out_features, in_features) + scale_f32 = scale.float() + scale_expanded = scale_f32.repeat_interleave(INT4_BLOCK_SIZE, dim=1) + scale_expanded = scale_expanded[:, :in_features_full] + + # Interleave lower and upper nibbles + unpacked = torch.zeros(out_features, in_features_full, dtype=torch.float32) + unpacked[:, 0::2] = lower_signed + unpacked[:, 1::2] = upper_signed + + # Dequantize + bf16_weight = (unpacked * scale_expanded).to(torch.bfloat16) + return bf16_weight + + def dequantize_model(model_dir: str, out_dir: str): os.makedirs(out_dir, exist_ok=True) # Copy non-safetensor files + print("Copying metadata files...") for f in os.listdir(model_dir): fp = os.path.join(model_dir, f) if not f.endswith(".safetensors") and os.path.isfile(fp): shutil.copy2(fp, os.path.join(out_dir, f)) - print(f"Copied {f}") + print(f" Copied {f}") safetensor_files = sorted(glob.glob(os.path.join(model_dir, "*.safetensors"))) - total = len(safetensor_files) - - # First pass: build map of scale_key -> weight_key - # Pattern: layers.X.attn.Y.scale -> layers.X.attn.Y.weight - scales_map = {} - - print("Scanning for FP8 weight+scale pairs...") + total_shards = len(safetensor_files) + print(f"Found {total_shards} shards") + + # First pass: build scale-key → weight-key mapping + # Pattern: *.scale → *.weight + print("\nScanning for weight+scale pairs...") + scale_to_weight = {} # scale_key → weight_key for f in safetensor_files: with safe_open(f, framework="pt") as sf: for key in sf.keys(): if key.endswith(".scale"): weight_key = key[:-len(".scale")] + ".weight" - scales_map[weight_key] = key - print(f"Found {len(scales_map)} FP8 weight+scale pairs") + scale_to_weight[key] = weight_key - # Second pass: dequantize and save - fp8_dequantized = 0 - fp8_scales_removed = 0 - scale_keys_global = set(scales_map.values()) + # Also find weight → scale mapping + weight_to_scale = {v: k for k, v in scale_to_weight.items()} + print(f"Found {len(scale_to_weight)} weight+scale pairs") + + # Classify weights by type + int4_weight_keys = set() + fp8_weight_keys = set() + scale_keys = set(scale_to_weight.keys()) - for i, f in enumerate(safetensor_files): - tensors = {} - scales_in_shard = {} - fp8_weights_in_shard = {} - + for f in safetensor_files[:2]: # Sample to classify with safe_open(f, framework="pt") as sf: for key in sf.keys(): + if key in weight_to_scale: + t = sf.get_tensor(key) + if t.dtype == torch.int8: + int4_weight_keys.add(key) + elif t.dtype == FP8_WEIGHT_DTYPE: + fp8_weight_keys.add(key) + + print(f" INT4 expert weights (packed): ~{len(int4_weight_keys)} per shard") + print(f" FP8 attention/shared-expert weights: ~{len(fp8_weight_keys)} per shard") + + # Second pass: dequantize and save + stats = {"int4_dequantized": 0, "fp8_dequantized": 0, "scales_removed": 0, "unchanged": 0} + start_time = time.time() + + for i, f in enumerate(safetensor_files): + shard_start = time.time() + tensors = {} + scales_in_shard = {} + weights_to_dequant = {} + + with safe_open(f, framework="pt") as sf: + keys = list(sf.keys()) + + # First: collect scales + for key in keys: + if key in scale_keys: + t = sf.get_tensor(key) + scales_in_shard[key] = t + + # Second: process weights and other tensors + for key in keys: + if key in scale_keys: + continue # handled separately + t = sf.get_tensor(key) - if key in scale_keys_global: - # This is a scale tensor - save for dequantization - scales_in_shard[key] = t - elif key in scales_map and t.dtype == FP8_WEIGHT_DTYPE: - # FP8 weight that has a corresponding scale - fp8_weights_in_shard[key] = t + if key in weight_to_scale and t.dtype == torch.int8: + # INT4 packed expert weight + scale_key = weight_to_scale[key] + scale = scales_in_shard.get(scale_key) + if scale is None: + print(f" WARNING: scale {scale_key} not in same shard as {key}") + tensors[key] = t # keep as-is + continue + bf16 = dequantize_int4_weight(t, scale) + tensors[key] = bf16 + stats["int4_dequantized"] += 1 + del scales_in_shard[scale_key] + stats["scales_removed"] += 1 + + elif key in weight_to_scale and t.dtype == FP8_WEIGHT_DTYPE: + # FP8 weight (attention or shared expert) + scale_key = weight_to_scale[key] + scale = scales_in_shard.get(scale_key) + if scale is None: + print(f" WARNING: scale {scale_key} not in same shard as {key}") + tensors[key] = t + continue + bf16 = dequantize_fp8_weight(t, scale) + tensors[key] = bf16 + stats["fp8_dequantized"] += 1 + del scales_in_shard[scale_key] + stats["scales_removed"] += 1 + else: - # Regular tensor (BF16, FP32, or FP8 without scale) - keep as is + # Regular tensor (BF16, FP32, int64, etc.) - keep as-is tensors[key] = t + stats["unchanged"] += 1 - # Dequantize FP8 weights - for weight_key, fp8_w in fp8_weights_in_shard.items(): - scale_key = scales_map[weight_key] - scale = scales_in_shard.get(scale_key) - if scale is not None: - bf16_weight = dequantize_weight(fp8_w, scale) - tensors[weight_key] = bf16_weight - fp8_dequantized += 1 - del scales_in_shard[scale_key] - fp8_scales_removed += 1 - else: - # Scale not in this shard (shouldn't happen but handle gracefully) - print(f"WARNING: scale {scale_key} not found in same shard as {weight_key}") - tensors[weight_key] = fp8_w # keep as-is + # Remove unused scales + for sk in scales_in_shard: + stats["scales_removed"] += 1 out_path = os.path.join(out_dir, os.path.basename(f)) save_file(tensors, out_path) - del tensors, scales_in_shard, fp8_weights_in_shard - if (i + 1) % 10 == 0 or i == total - 1: - print(f"[{i+1}/{total}] {os.path.basename(f)} (dequantized {fp8_dequantized} FP8, removed {fp8_scales_removed} scales)") + shard_time = time.time() - shard_start + elapsed = time.time() - start_time + rate = (i + 1) / elapsed if elapsed > 0 else 0 + eta = (total_shards - i - 1) / rate if rate > 0 else 0 + + print(f"[{i+1}/{total_shards}] {os.path.basename(f)} " + f"({stats['int4_dequantized']} int4, {stats['fp8_dequantized']} fp8, " + f"{stats['scales_removed']} scales removed) " + f"[{shard_time:.1f}s, ETA: {eta/60:.0f}min]") + + del tensors, scales_in_shard # Update config cfg_path = os.path.join(out_dir, "config.json") @@ -120,14 +223,38 @@ def dequantize_model(model_dir: str, out_dir: str): if "quantization_config" in cfg: del cfg["quantization_config"] json.dump(cfg, open(cfg_path, "w"), indent=2) - print(f"Updated config.json: torch_dtype=bfloat16, _experts_implementation=eager") + print(f"\nUpdated config.json: torch_dtype=bfloat16, _experts_implementation=eager") - print(f"\nDone! Dequantized {fp8_dequantized} FP8 weights, removed {fp8_scales_removed} scale tensors") + total_time = time.time() - start_time + print(f"\nDone in {total_time/60:.1f} minutes!") + print(f" INT4 expert weights dequantized: {stats['int4_dequantized']}") + print(f" FP8 weights dequantized: {stats['fp8_dequantized']}") + print(f" Scale tensors removed: {stats['scales_removed']}") + print(f" Unchanged tensors: {stats['unchanged']}") + + # Verify no FP8/INT8 remaining + print("\nVerifying...") + remaining_compressed = 0 + for f in sorted(glob.glob(os.path.join(out_dir, "*.safetensors")))[:5]: + with safe_open(f, framework="pt") as sf: + for key in sf.keys(): + t = sf.get_tensor(key) + if t.dtype in (torch.float8_e8m0fnu, torch.float8_e4m3fn, torch.int8): + remaining_compressed += 1 + if remaining_compressed <= 5: + print(f" REMAINING: {key} {t.dtype} {t.shape}") + if remaining_compressed == 0: + print(" ✅ No compressed tensors remaining - model is pure BF16!") + else: + print(f" ⚠️ {remaining_compressed} compressed tensors still present") + + out_size = sum(os.path.getsize(os.path.join(out_dir, f)) for f in os.listdir(out_dir) if f.endswith(".safetensors")) + print(f"Output size: {out_size / 1e12:.2f} TB") if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description="Dequantize FP8 weights to BF16") + parser = argparse.ArgumentParser(description="Complete dequantization of DeepSeek V4 Pro to BF16") parser.add_argument("model_dir", help="Path to mixed-precision model") parser.add_argument("out_dir", help="Path to write dequantized BF16 model") args = parser.parse_args()