Verified that our NVFP4 packing convention (odd<<4|even, round-half-to-even) matches the DeepSeek-V4 checkpoint exactly: 100% byte-identical round-trip across all tested experts. The dequantize->requantize path is lossless in practice but wasteful. Marked both prepare_weights_from_dequantized and prepare_weights_direct as deprecated in favor of prepare_weights_from_stacked which loads checkpoint FP4 bytes directly via .view(). Also added test_fp4_roundtrip.py for future reference.
147 lines
5.5 KiB
Python
147 lines
5.5 KiB
Python
"""Test: check whether dequantize→requantize preserves the checkpoint's FP4 weights.

If our rounding convention differs from the checkpoint's (round-half-up vs.
round-half-to-even), values land in different FP4 buckets, and the accumulated
error across 1.6T weights matters.

This test:
1. Loads a single expert's checkpoint FP4 weights.
2. Dequantizes them to BF16 (using the checkpoint's block scales and global scale).
3. Re-quantizes BF16 → FP4 (using our quantize_weight_to_nvfp4).
4. Dequantizes the re-quantized weights back to BF16.
5. Compares ||W_ours - W_checkpoint||_F / ||W_checkpoint||_F.

It also reports the raw FP4 byte match rate, the block-scale match rate, and
the difference between the global scales.

A relative error above 1e-3 indicates a rounding mismatch that matters.
"""
|
|
import sys
|
|
import torch
|
|
|
|
def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale, block_size=None):
    """Dequantize an NVFP4 weight tensor to BF16.

    Each packed byte holds two E2M1 codes. The low nibble is the even-indexed
    element and the high nibble is the odd-indexed element. Every decoded value
    is multiplied by its block's scale and by ``global_scale``.

    Args:
        packed_uint8: (N, K_packed) tensor of packed FP4 bytes. Any 1-byte
            dtype is accepted; it is reinterpreted as uint8.
        scale_e4m3: (N, K_sf) block scales (float8_e4m3fn in the checkpoint;
            any dtype supporting ``.float()`` works). Must be a plain row-major
            grid; swizzled kernel layouts are not supported.
        global_scale: float32 scalar (Python float or 0-dim tensor) multiplied
            into every element.
        block_size: Elements per scale block. If None (default), it is
            inferred as K // K_sf, and the scale grid must tile the weight
            exactly. If given, a scale tensor padded beyond
            (N, K // block_size) is accepted and the padding is ignored.

    Returns:
        (N, K) BF16 weight matrix, with K = 2 * K_packed.

    Raises:
        ValueError: if either input is not 2-D, the packed dtype is not 1 byte
            wide, or the scale shape is inconsistent with the weight shape.
    """
    if packed_uint8.dim() != 2 or scale_e4m3.dim() != 2:
        raise ValueError(
            f"expected 2-D inputs, got packed {tuple(packed_uint8.shape)} "
            f"and scales {tuple(scale_e4m3.shape)}"
        )
    if packed_uint8.element_size() != 1:
        raise ValueError(
            f"packed FP4 data must be 1 byte per element, got {packed_uint8.dtype}"
        )

    raw = packed_uint8.view(torch.uint8)
    n_rows, k_packed = raw.shape
    k = k_packed * 2

    sf_rows, sf_cols = scale_e4m3.shape
    if block_size is None:
        # Infer the block size (16 for NVFP4). The scale grid must tile each row
        # exactly; otherwise an inferred block size would silently misalign.
        if sf_rows != n_rows or sf_cols == 0 or k % sf_cols != 0:
            raise ValueError(
                f"scale shape {tuple(scale_e4m3.shape)} does not evenly tile "
                f"weight shape ({n_rows}, {k})"
            )
        block_size = k // sf_cols
    else:
        if block_size <= 0 or k % block_size != 0:
            raise ValueError(f"block_size={block_size} does not divide K={k}")
        needed_cols = k // block_size
        if sf_rows < n_rows or sf_cols < needed_cols:
            raise ValueError(
                f"scale shape {tuple(scale_e4m3.shape)} is smaller than the "
                f"required ({n_rows}, {needed_cols})"
            )
        # Drop alignment padding (valid for row-major layouts only).
        scale_e4m3 = scale_e4m3[:n_rows, :needed_cols]

    # E2M1 code -> value. Bit 3 is the sign and bits 0-2 index the magnitudes
    # [0, 0.5, 1, 1.5, 2, 3, 4, 6], so codes 8-15 are the negations of codes 0-7.
    lut = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
         -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
        dtype=torch.float32,
        device=raw.device,
    )
    even = lut[(raw & 0x0F).long()]
    odd = lut[((raw >> 4) & 0x0F).long()]
    # Interleave so that element 2j comes from the low nibble of byte j.
    values = torch.stack((even, odd), dim=-1).reshape(n_rows, k)

    # Broadcast each block scale over its block_size consecutive elements.
    scales = scale_e4m3.float().repeat_interleave(block_size, dim=1)
    dequant = values * scales * global_scale
    return dequant.to(torch.bfloat16)
|
|
|
|
|
|
def _load_checkpoint_tensors(files, names):
    """Load each tensor in *names* from whichever safetensors shard holds it.

    Only the requested tensors are read, rather than materializing a whole
    shard with ``load_file``. Raises KeyError if any name is in no shard.
    """
    from safetensors import safe_open

    remaining = set(names)
    found = {}
    for path in files:
        if not remaining:
            break
        with safe_open(path, framework="pt", device="cpu") as shard:
            for name in remaining.intersection(shard.keys()):
                found[name] = shard.get_tensor(name)
        remaining.difference_update(found)
    if remaining:
        raise KeyError(f"Tensors not found in any shard: {sorted(remaining)}")
    return found


def _report_byte_mismatches(ours, theirs, max_samples=5):
    """Print the count of differing packed FP4 bytes and a few example positions."""
    mismatch = ours != theirs
    mismatch_count = mismatch.sum().item()
    total = theirs.numel()
    print(f"Byte mismatches: {mismatch_count}/{total} ({mismatch_count/total*100:.2f}%)")
    for r, c in mismatch.nonzero()[:max_samples].tolist():
        print(f" [{r},{c}] ours=0x{ours[r, c].item():02x} checkpoint=0x{theirs[r, c].item():02x}")


def test_roundtrip():
    """Round-trip one expert's gate_proj through dequantize -> requantize and compare.

    Prints diagnostics: relative and max error, FP4 byte match rate,
    block-scale match rate, and global scales. Then asserts that the relative
    Frobenius error is <= 1e-3.

    Requires the local checkpoint, a CUDA device, and the project's
    ``cutedsl.bridge`` module. If no checkpoint shards are found, it prints a
    message and returns without checking anything.
    """
    import glob

    files = sorted(glob.glob('/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4/*.safetensors'))
    if not files:
        print("No safetensor files found")
        return

    # Test gate_proj of expert 0 in layer 0. Search all shards for it rather
    # than assuming it lives in the first one.
    prefix = 'model.layers.0.mlp.experts.0.gate_proj'
    w_key = f'{prefix}.weight'
    sf_key = f'{prefix}.weight_scale'
    gs_key = f'{prefix}.weight_scale_2'
    tensors = _load_checkpoint_tensors(files, [w_key, sf_key, gs_key])
    gate_w = tensors[w_key].cuda()  # (3072, 3584) uint8
    gate_sf = tensors[sf_key].cuda()  # (3072, 448) fp8
    gate_gs = tensors[gs_key].item()  # float32

    print(f"Checkpoint: gate_proj shape={gate_w.shape}, sf shape={gate_sf.shape}, gs={gate_gs}")

    # Step 1: Dequantize checkpoint -> BF16
    gate_bf16 = dequantize_nvfp4_weight(gate_w, gate_sf, gate_gs)
    print(f"Dequantized: shape={gate_bf16.shape}, amax={gate_bf16.abs().amax().item():.6f}")

    # Step 2: Re-quantize BF16 -> FP4 using our convention
    kernel_dir = '/root/dsv4-nvfp4-workspace/kernel'
    if kernel_dir not in sys.path:
        sys.path.insert(0, kernel_dir)
    from cutedsl.bridge import quantize_weight_to_nvfp4

    # NOTE(review): earlier comments disagreed about which axis
    # quantize_weight_to_nvfp4 packs along. The checkpoint packs along the last
    # axis ((3072, 7168) BF16 -> (3072, 3584) uint8). The byte comparison below
    # is only meaningful if our output uses the same layout; confirm.
    w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(gate_bf16)
    print(f"Re-quantized: fp4 shape={w_fp4.shape}, sf shape={w_sf.shape}, gs={w_gs}")
    # Reinterpret as raw bytes once; the quantizer may return a non-uint8 dtype.
    our_uint8 = w_fp4 if w_fp4.dtype == torch.uint8 else w_fp4.view(torch.uint8)

    # Step 3: Dequantize the re-quantized weights
    gate_bf16_ours = dequantize_nvfp4_weight(our_uint8, w_sf, w_gs)

    # Step 4: Compare
    diff = (gate_bf16_ours - gate_bf16).float()
    rel_err = (diff.norm() / gate_bf16.float().norm()).item()
    max_err = diff.abs().max().item()

    print("\n=== Results ===")
    print(f"Relative error (Frobenius): {rel_err:.6f}")
    print(f"Max absolute error: {max_err:.6f}")
    print(f"Exceeds 1e-3 threshold: {rel_err > 1e-3}")

    # Step 5: Compare raw FP4 bytes. Element-wise comparison needs identical layouts.
    if our_uint8.shape == gate_w.shape:
        byte_match = (our_uint8 == gate_w).float().mean().item()
        print(f"FP4 byte match rate: {byte_match:.4f}")
        # Step 6: Check where they differ
        if byte_match < 1.0:
            _report_byte_mismatches(our_uint8, gate_w)
    else:
        print(
            f"FP4 byte layout differs: ours={tuple(our_uint8.shape)}, "
            f"checkpoint={tuple(gate_w.shape)}"
        )

    # Step 7: Also compare scales
    if w_sf.shape == gate_sf.shape and w_sf.dtype == gate_sf.dtype:
        sf_match = (w_sf == gate_sf).float().mean()
        print(f"Block scale match rate: {sf_match.item():.4f}")

    print(f"\nGlobal scale: ours={w_gs:.8f}, checkpoint={gate_gs:.8f}, diff={abs(w_gs - gate_gs):.8f}")

    # Fail, rather than only print, when the documented criterion is violated,
    # so the test is meaningful when collected by a test runner.
    assert rel_err <= 1e-3, f"Round-trip relative error {rel_err:.6f} exceeds 1e-3"
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the round-trip check directly.
    # test_roundtrip() returns None, so the process exits with status 0 unless
    # it raises.
    raise SystemExit(test_roundtrip())
|