Unlike a naive upcast, this properly dequantizes block-wise FP8 weights: bf16 = fp8_weight * scale_expanded, with one scale entry per 128x128 block. It also removes the now-unnecessary scale tensors and updates config.json (torch_dtype=bfloat16, quantization_config removed). Afterwards, FP8Linear.forward() sees element_size() > 1 and falls back to F.linear().
135 lines
5.2 KiB
Python
135 lines
5.2 KiB
Python
#!/usr/bin/env python3
"""
Dequantize FP8 weights to BF16 for a DeepSeek V4 Pro mixed-precision model.

The converted checkpoint is written to a separate output directory.

For each FP8 weight tensor (float8_e4m3fn) paired with a block-wise scale
(float8_e8m0fnu, one entry per 128x128 block), the BF16 weight is
reconstructed as:

    bf16_weight = fp8_weight * scale_expanded

where scale_expanded repeats each scale entry over its block.

After dequantization, FP8Linear.forward() sees element_size() > 1 and
falls back to F.linear(), avoiding the broken FP8 kernel paths on Blackwell.

Preserves all BF16 and FP32 tensors unchanged.
Removes the now-unnecessary scale tensors from the output.
"""
|
|
|
|
import os, glob, json, shutil
|
|
from safetensors import safe_open
|
|
from safetensors.torch import save_file
|
|
import torch
|
|
|
|
FP8_WEIGHT_DTYPE = torch.float8_e4m3fn
|
|
FP8_SCALE_DTYPE = torch.float8_e8m0fnu
|
|
BLOCK_SIZE = (128, 128)
|
|
|
|
|
|
def dequantize_weight(fp8_weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
|
"""Dequantize block-wise FP8 weight to BF16.
|
|
|
|
fp8_weight: (out_features, in_features) float8_e4m3fn
|
|
scale: (out_features//128, in_features//128) float8_e8m0fnu or float32
|
|
"""
|
|
scale_f32 = scale.float()
|
|
|
|
out_features, in_features = fp8_weight.shape
|
|
scale_expanded = scale_f32.repeat_interleave(BLOCK_SIZE[0], dim=0).repeat_interleave(BLOCK_SIZE[1], dim=1)
|
|
scale_expanded = scale_expanded[:out_features, :in_features]
|
|
|
|
weight_bf16 = fp8_weight.float() * scale_expanded
|
|
return weight_bf16.to(torch.bfloat16)
|
|
|
|
|
|
def _copy_aux_files(model_dir: str, out_dir: str) -> None:
    """Copy every top-level file of ``model_dir`` that is not a .safetensors shard."""
    for name in sorted(os.listdir(model_dir)):
        src = os.path.join(model_dir, name)
        if os.path.isfile(src) and not name.endswith(".safetensors"):
            shutil.copy2(src, os.path.join(out_dir, name))
            print(f"Copied {name}")


def _load_tensor(path: str, key: str) -> torch.Tensor:
    """Load a single tensor ``key`` from the shard at ``path``."""
    with safe_open(path, framework="pt") as sf:
        return sf.get_tensor(key)


def _scan_shards(safetensor_files):
    """Map every tensor key to its shard, and pair ``X.weight`` with ``X.scale``.

    Returns ``(key_to_file, scales_map)`` where ``scales_map`` maps a weight
    key to its scale key. A ``.scale`` tensor is paired only if the matching
    ``.weight`` tensor exists in some shard.
    """
    key_to_file = {}
    for path in safetensor_files:
        with safe_open(path, framework="pt") as sf:
            for key in sf.keys():
                key_to_file[key] = path

    scales_map = {}
    for key in key_to_file:
        if key.endswith(".scale"):
            weight_key = key[: -len(".scale")] + ".weight"
            if weight_key in key_to_file:
                scales_map[weight_key] = key
    return key_to_file, scales_map


def _rewrite_shard(path, out_path, scales_map, scale_to_weight, key_to_file):
    """Write a dequantized copy of one shard to ``out_path``.

    FP8 weights with a paired scale are dequantized (the scale may live in any
    shard). A paired scale is dropped only if its weight is FP8, i.e. it was
    (or will be) consumed by dequantization; otherwise it is kept unchanged.

    Returns ``(num_dequantized, removed_scale_keys, bytes_written)``.
    """
    tensors = {}
    dequantized = set()
    removed = []
    with safe_open(path, framework="pt") as sf:
        metadata = sf.metadata()
        keys = list(sf.keys())

        # Weights and regular tensors first, so the scale pass below knows
        # which weights in this shard were actually dequantized.
        for key in keys:
            if key in scale_to_weight:
                continue
            tensor = sf.get_tensor(key)
            scale_key = scales_map.get(key)
            if scale_key is not None and tensor.dtype == FP8_WEIGHT_DTYPE:
                scale_path = key_to_file[scale_key]
                if scale_path == path:
                    scale = sf.get_tensor(scale_key)
                else:
                    scale = _load_tensor(scale_path, scale_key)
                tensor = dequantize_weight(tensor, scale)
                dequantized.add(key)
            tensors[key] = tensor

        for key in keys:
            weight_key = scale_to_weight.get(key)
            if weight_key is None:
                continue
            weight_path = key_to_file[weight_key]
            if weight_path == path:
                weight_is_fp8 = weight_key in dequantized
            else:
                # Rare cross-shard pair: inspect the weight's dtype directly.
                weight_is_fp8 = (
                    _load_tensor(weight_path, weight_key).dtype == FP8_WEIGHT_DTYPE
                )
            if weight_is_fp8:
                removed.append(key)
            else:
                tensors[key] = sf.get_tensor(key)

    # Preserve the shard's original __metadata__ (e.g. {"format": "pt"}).
    save_file(tensors, out_path, metadata=metadata)
    nbytes = sum(t.numel() * t.element_size() for t in tensors.values())
    return len(dequantized), removed, nbytes


def _update_index_files(out_dir, removed_keys, total_size) -> None:
    """Drop removed keys from copied *.safetensors.index.json files and refresh total_size."""
    for index_path in sorted(glob.glob(os.path.join(out_dir, "*.safetensors.index.json"))):
        with open(index_path, encoding="utf-8") as fh:
            index = json.load(fh)
        weight_map = index.get("weight_map")
        if isinstance(weight_map, dict):
            for key in removed_keys:
                weight_map.pop(key, None)
        index.setdefault("metadata", {})["total_size"] = total_size
        with open(index_path, "w", encoding="utf-8") as fh:
            json.dump(index, fh, indent=2)
        print(f"Updated {os.path.basename(index_path)}: removed {len(removed_keys)} scale entries, total_size={total_size}")


def _update_config(out_dir: str) -> None:
    """Mark the copied config.json as bfloat16 and drop its quantization config."""
    cfg_path = os.path.join(out_dir, "config.json")
    if not os.path.exists(cfg_path):
        return
    with open(cfg_path, encoding="utf-8") as fh:
        cfg = json.load(fh)
    cfg["torch_dtype"] = "bfloat16"
    cfg["_experts_implementation"] = "eager"
    cfg.pop("quantization_config", None)
    with open(cfg_path, "w", encoding="utf-8") as fh:
        json.dump(cfg, fh, indent=2)
    print("Updated config.json: torch_dtype=bfloat16, _experts_implementation=eager")


def dequantize_model(model_dir: str, out_dir: str):
    """Write a copy of ``model_dir`` with block-wise FP8 weights dequantized to BF16.

    Steps:
      1. Copy every top-level non-.safetensors file (config, shard index, ...).
      2. Pair each ``X.scale`` tensor with an existing ``X.weight``, across shards.
      3. Rewrite each shard: FP8 (float8_e4m3fn) weights with a scale become
         ``dequantize_weight(weight, scale)``, their scales are dropped, and
         every other tensor is written unchanged.
      4. Remove the dropped scale keys from any copied *.safetensors.index.json
         and refresh its metadata.total_size.
      5. Set torch_dtype=bfloat16 and _experts_implementation=eager in
         config.json and remove quantization_config.

    Args:
        model_dir: directory containing the source checkpoint.
        out_dir: destination directory (created if missing). It must differ from
            ``model_dir`` because shards are rewritten while being read.

    Raises:
        ValueError: if ``out_dir`` resolves to the same directory as ``model_dir``.
    """
    if os.path.realpath(model_dir) == os.path.realpath(out_dir):
        raise ValueError("out_dir must differ from model_dir")

    os.makedirs(out_dir, exist_ok=True)
    _copy_aux_files(model_dir, out_dir)

    safetensor_files = sorted(glob.glob(os.path.join(model_dir, "*.safetensors")))
    total = len(safetensor_files)
    if not safetensor_files:
        print(f"WARNING: no .safetensors files found in {model_dir}")

    print("Scanning for FP8 weight+scale pairs...")
    key_to_file, scales_map = _scan_shards(safetensor_files)
    scale_to_weight = {scale_key: weight_key for weight_key, scale_key in scales_map.items()}
    print(f"Found {len(scales_map)} FP8 weight+scale pairs")

    fp8_dequantized = 0
    removed_scale_keys = []
    total_size = 0
    for i, path in enumerate(safetensor_files):
        out_path = os.path.join(out_dir, os.path.basename(path))
        n_dequantized, removed, nbytes = _rewrite_shard(
            path, out_path, scales_map, scale_to_weight, key_to_file
        )
        fp8_dequantized += n_dequantized
        removed_scale_keys.extend(removed)
        total_size += nbytes

        if (i + 1) % 10 == 0 or i == total - 1:
            print(f"[{i+1}/{total}] {os.path.basename(path)} (dequantized {fp8_dequantized} FP8, removed {len(removed_scale_keys)} scales)")

    if safetensor_files:
        _update_index_files(out_dir, removed_scale_keys, total_size)
    _update_config(out_dir)

    print(f"\nDone! Dequantized {fp8_dequantized} FP8 weights, removed {len(removed_scale_keys)} scale tensors")
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: two positional directories, source then destination.
    from argparse import ArgumentParser

    cli = ArgumentParser(description="Dequantize FP8 weights to BF16")
    for arg_name, arg_help in (
        ("model_dir", "Path to mixed-precision model"),
        ("out_dir", "Path to write dequantized BF16 model"),
    ):
        cli.add_argument(arg_name, help=arg_help)
    parsed = cli.parse_args()
    dequantize_model(parsed.model_dir, parsed.out_dir)
|