Files
nvfp4-megamoe-kernel/tests/unit/test_fp4_roundtrip.py
biondizzle 3fb3c925af Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00

149 lines
5.5 KiB
Python

"""Test: Check if dequantize→requantize preserves checkpoint FP4 weights.
If the rounding convention differs (round-half-up vs round-half-to-even),
the FP4 buckets shift and accumulated error across 1.6T weights matters.
This test:
1. Loads a single expert's checkpoint FP4 weights
2. Dequantizes them to BF16 (using the checkpoint's scales/gs)
3. Re-quantizes BF16 → FP4 (using our quantize_weight_to_nvfp4)
4. Dequantizes the re-quantized weights back to BF16
5. Compares: ||W_ours - W_checkpoint||_F / ||W_checkpoint||_F
If > 1e-3, there's a rounding mismatch that matters.
"""
import sys
import torch
def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale):
"""Dequantize NVFP4 weight tensor to BF16.
Args:
packed_uint8: (N, K_packed) uint8 — packed FP4 bytes
scale_e4m3: (N, K_sf) float8_e4m3fn — block scales
global_scale: float32 scalar
Returns:
(N, K_original) BF16 weight matrix
"""
# Unpack FP4 → BF16
raw = packed_uint8.view(torch.uint8)
low = (raw & 0x0F).to(torch.int8) # even elements
high = ((raw >> 4) & 0x0F).to(torch.int8) # odd elements
# E2M1 magnitudes: [0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
e2m1_values = torch.tensor([0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
dtype=torch.float32, device=raw.device)
# Sign: bit 3 = 1 means negative
low_sign = (low >> 3).bool()
low_idx = (low & 0x07)
high_sign = (high >> 3).bool()
high_idx = (high & 0x07)
low_mag = e2m1_values[low_idx.long()]
high_mag = e2m1_values[high_idx.long()]
low_val = torch.where(low_sign, -low_mag, low_mag)
high_val = torch.where(high_sign, -high_mag, high_mag)
# Interleave low and high
N, K_packed = packed_uint8.shape
K = K_packed * 2
values = torch.stack([low_val, high_val], dim=-1).reshape(N, K)
# Dequantize: value * block_scale * global_scale
K_sf = scale_e4m3.shape[1]
block_size = K // K_sf # Should be 16
scale_expanded = scale_e4m3.float().unsqueeze(2).expand(-1, -1, block_size).reshape(N, K)
dequant = values * scale_expanded * global_scale
return dequant.to(torch.bfloat16)
def _load_checkpoint_tensors(files, keys):
    """Load only the tensors named in ``keys`` from a list of safetensors shards.

    Uses ``safe_open`` so only shard headers and the requested tensors are read,
    rather than materializing whole shards in host memory. Shards are scanned in
    order, and scanning stops once every key has been found.

    Args:
        files: Paths to ``*.safetensors`` shards.
        keys: Tensor names to load.

    Returns:
        Dict mapping each key that was found to its CPU tensor. Keys absent from
        every shard are missing from the result.
    """
    from safetensors import safe_open

    wanted = set(keys)
    found = {}
    for path in files:
        with safe_open(path, framework="pt", device="cpu") as shard:
            for key in wanted.intersection(shard.keys()):
                found[key] = shard.get_tensor(key)
        wanted.difference_update(found)
        if not wanted:
            break
    return found


def _report_byte_mismatches(ours, theirs, max_samples=5):
    """Print how many packed bytes differ, plus the first few differing positions.

    Args:
        ours: 2-D uint8 tensor of packed FP4 bytes from our quantizer.
        theirs: Same-shape uint8 tensor of checkpoint bytes.
        max_samples: Maximum number of example positions to print.
    """
    mismatch = ours != theirs
    mismatch_count = mismatch.sum().item()
    total = theirs.numel()
    print(f"Byte mismatches: {mismatch_count}/{total} ({mismatch_count/total*100:.2f}%)")
    for r, c in mismatch.nonzero()[:max_samples].tolist():
        print(f" [{r},{c}] ours=0x{ours[r, c].item():02x} checkpoint=0x{theirs[r, c].item():02x}")


def test_roundtrip():
    """Check that re-quantizing a dequantized checkpoint weight reproduces it.

    Loads one expert's ``gate_proj`` NVFP4 weight, dequantizes it, re-quantizes
    it with ``quantize_weight_to_nvfp4``, and dequantizes again. It then reports
    the relative Frobenius error and the max absolute error, and how well the
    packed bytes, block scales and global scale agree.

    Environment overrides (defaults are the original hard-coded paths):
        DSV4_NVFP4_CKPT_DIR: checkpoint directory containing ``*.safetensors``.
        DSV4_KERNEL_DIR: directory prepended to ``sys.path`` to import ``dsv4``.

    Raises:
        unittest.SkipTest: No shards were found or CUDA is unavailable (pytest
            reports a skip rather than a silent pass).
        KeyError: The expert's tensors are not present in any shard.
        AssertionError: Re-quantized shapes differ from the checkpoint, or the
            relative error exceeds 1e-3 (or is NaN).
    """
    import glob
    import os
    import unittest

    ckpt_dir = os.environ.get("DSV4_NVFP4_CKPT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
    kernel_dir = os.environ.get("DSV4_KERNEL_DIR", "/root/dsv4-nvfp4-workspace/kernel")
    prefix = "model.layers.0.mlp.experts.0.gate_proj"
    rel_tol = 1e-3

    files = sorted(glob.glob(os.path.join(ckpt_dir, "*.safetensors")))
    if not files:
        raise unittest.SkipTest(f"No safetensor files found in {ckpt_dir}")
    if not torch.cuda.is_available():
        raise unittest.SkipTest("CUDA is not available")

    # Locate the expert's tensors in whichever shard(s) hold them instead of
    # assuming they live in the first shard.
    weight_key = f"{prefix}.weight"
    scale_key = f"{prefix}.weight_scale"
    gs_key = f"{prefix}.weight_scale_2"
    wanted_keys = (weight_key, scale_key, gs_key)
    tensors = _load_checkpoint_tensors(files, wanted_keys)
    missing = [key for key in wanted_keys if key not in tensors]
    if missing:
        raise KeyError(f"{missing} not found in any shard under {ckpt_dir}")

    gate_w = tensors[weight_key].cuda()  # packed FP4 bytes; (3072, 3584) uint8 expected
    gate_sf = tensors[scale_key].cuda()  # block scales; (3072, 448) fp8 expected
    gate_gs = tensors[gs_key].item()  # per-tensor global scale
    print(f"Checkpoint: gate_proj shape={gate_w.shape}, sf shape={gate_sf.shape}, gs={gate_gs}")

    # Step 1: Dequantize checkpoint → BF16
    gate_bf16 = dequantize_nvfp4_weight(gate_w, gate_sf, gate_gs)
    print(f"Dequantized: shape={gate_bf16.shape}, amax={gate_bf16.abs().amax().item():.6f}")

    # Step 2: Re-quantize BF16 → FP4 using our convention
    if kernel_dir not in sys.path:  # avoid growing sys.path on repeated calls
        sys.path.insert(0, kernel_dir)
    from dsv4.ops.quantize import quantize_weight_to_nvfp4

    # NOTE(review): gate_bf16 has the checkpoint's layout (rows x unpacked
    # cols, packed along the last dim). Confirm that quantize_weight_to_nvfp4
    # packs along the same dim, emits non-swizzled (rows, cols // 16) scales,
    # and returns its global scale in the same multiplicative convention as
    # weight_scale_2. Otherwise the comparisons below are not like-for-like.
    w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(gate_bf16)
    print(f"Re-quantized: fp4 shape={w_fp4.shape}, sf shape={w_sf.shape}, gs={w_gs}")

    # view() is a no-op for uint8; .to() is a no-op when already on that device.
    our_uint8 = w_fp4.view(torch.uint8).to(gate_w.device)
    assert our_uint8.shape == gate_w.shape, (
        f"re-quantized packed shape {tuple(our_uint8.shape)} != checkpoint {tuple(gate_w.shape)}"
    )

    # Step 3: Dequantize the re-quantized weights
    gate_bf16_ours = dequantize_nvfp4_weight(our_uint8, w_sf, w_gs)

    # Step 4: Compare in float32. Upcast before subtracting so the difference
    # itself is not rounded to bf16.
    diff = gate_bf16_ours.float() - gate_bf16.float()
    rel_err = (diff.norm() / gate_bf16.float().norm()).item()
    max_err = diff.abs().max().item()
    passed = rel_err <= rel_tol  # False for NaN, so NaN counts as a failure
    print("\n=== Results ===")
    print(f"Relative error (Frobenius): {rel_err:.6f}")
    print(f"Max absolute error: {max_err:.6f}")
    print(f"Exceeds {rel_tol:g} threshold: {not passed}")

    # Step 5: Compare raw FP4 bytes
    byte_match = (our_uint8 == gate_w).float().mean().item()
    print(f"FP4 byte match rate: {byte_match:.4f}")

    # Step 6: Show where they differ
    if byte_match < 1.0:
        _report_byte_mismatches(our_uint8, gate_w)

    # Step 7: Also compare scales. Compare raw bits, because comparison kernels
    # may not be implemented for float8 dtypes.
    if w_sf.shape == gate_sf.shape and w_sf.dtype == gate_sf.dtype:
        ours_sf_bits = w_sf.view(torch.uint8).to(gate_sf.device)
        sf_match = (ours_sf_bits == gate_sf.view(torch.uint8)).float().mean()
        print(f"Block scale match rate: {sf_match.item():.4f}")
    else:
        print(
            f"Block scales not comparable: ours {tuple(w_sf.shape)} {w_sf.dtype}, "
            f"checkpoint {tuple(gate_sf.shape)} {gate_sf.dtype}"
        )

    ours_gs = float(w_gs)  # works for Python floats and 1-element tensors
    print(f"\nGlobal scale: ours={ours_gs:.8f}, checkpoint={gate_gs:.8f}, diff={abs(ours_gs - gate_gs):.8f}")

    assert passed, (
        f"dequant→requant relative error {rel_err:.6f} exceeds {rel_tol:g}; "
        "this suggests a rounding-convention mismatch with the checkpoint"
    )
if __name__ == "__main__":
    import unittest

    # Script mode: if the check signals a skip by raising unittest.SkipTest
    # (which pytest reports as a skip), print the reason instead of a traceback.
    try:
        test_roundtrip()
    except unittest.SkipTest as exc:
        print(f"SKIPPED: {exc}")