#!/usr/bin/env python3
"""Isolate NVFP4 GEMM error: compare production weight dequant vs. reference.

Tests whether the issue is in:
  1. Weight/scale layout conversion (make_b_k_major, swizzle)
  2. Activation quantization (global_scale, block_scale)
  3. The GEMM kernel itself

Strategy: bypass activation quantization by passing a pre-quantized FP4
activation, and compare against a pure weight-dequant reference.
"""
|
|
import os, sys, json, math, torch, torch.nn.functional as F
|
|
from pathlib import Path
|
|
|
|
# Root directory of the NVFP4 checkpoint. The CHECKPOINT_DIR environment
# variable, when set, overrides the built-in default path.
CHECKPOINT_DIR = os.getenv(
    "CHECKPOINT_DIR",
    "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4",
)
|
|
# Magnitudes selected by the low three bits of a 4-bit weight code; bit 3 of
# the code is the sign (see dequant_nvfp4).
FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def dequant_nvfp4(weight, weight_scale, weight_scale_2=None, input_scale=None):
    """Dequantize a packed 4-bit weight matrix to bfloat16.

    Args:
        weight: 2-D tensor of packed bytes with shape ``(out, in // 2)``.
            Each byte holds two 4-bit codes; the low nibble is the earlier
            element along the input axis. Tensors whose dtype is not
            ``torch.uint8`` are reinterpreted as raw bytes first, so a
            weight already typed as a packed-FP4 dtype is accepted too.
        weight_scale: per-block scales of shape ``(out, in // 16)``; each
            scale is applied to 16 consecutive input elements.
        weight_scale_2: optional extra scale multiplied into the block
            scales (broadcast against them).
        input_scale: accepted for signature compatibility; not used.

    Returns:
        A bfloat16 tensor of shape ``(out, in)``.
    """
    # Bitwise ops are only defined for integer dtypes; view the packed bytes
    # as uint8 so the nibble extraction below works for any 1-byte storage.
    if weight.dtype != torch.uint8:
        weight = weight.view(torch.uint8)

    out_features, packed_in = weight.shape

    # 16-entry table indexed by the full nibble: codes 0-7 are the positive
    # magnitudes, codes 8-15 (sign bit set) are their negations.
    magnitudes = FP4_LUT.to(device=weight.device, dtype=torch.float32)
    signed_lut = torch.cat([magnitudes, -magnitudes])

    low_vals = signed_lut[(weight & 0x0F).long()]
    high_vals = signed_lut[(weight >> 4).long()]

    # Interleave (low, high) per byte to restore the original element order.
    values = torch.stack([low_vals, high_vals], dim=-1).reshape(out_features, packed_in * 2)

    scale = weight_scale.float().repeat_interleave(16, dim=1)
    if weight_scale_2 is not None:
        scale = scale * weight_scale_2.float()

    return (values * scale).bfloat16()
|
|
|
|
def get_nvfp4_weight(w, pfx, proj_name):
    """Look up the four tensors stored for one projection.

    Builds the key base ``"{pfx}.{proj_name}"`` and fetches, in order, the
    ``weight``, ``weight_scale``, ``weight_scale_2`` and ``input_scale``
    entries from the mapping ``w``.

    Returns:
        A 4-tuple in the order above; any entry missing from ``w`` is
        ``None``.
    """
    base = f"{pfx}.{proj_name}"
    suffixes = ("weight", "weight_scale", "weight_scale_2", "input_scale")
    return tuple(w.get(f"{base}.{suffix}") for suffix in suffixes)
|
|
|
|
def _load_tensors_with_prefixes(ckpt_dir, prefixes):
    """Load only the checkpoint tensors whose names start with one of *prefixes*.

    Reads ``model.safetensors.index.json`` under *ckpt_dir*, opens just the
    shard files that the index maps a matching tensor name to, and returns a
    ``{name: tensor}`` dict restricted to matching names. Returns an empty
    dict when the index is absent; shard files missing on disk are skipped.
    """
    from safetensors.torch import load_file

    index_path = ckpt_dir / "model.safetensors.index.json"
    weight_map = {}
    if index_path.exists():
        with open(index_path) as f:
            weight_map = json.load(f).get("weight_map", {})

    shard_names = sorted(
        {shard for name, shard in weight_map.items() if name.startswith(prefixes)}
    )

    tensors = {}
    for shard_name in shard_names:
        shard_path = ckpt_dir / shard_name
        if not shard_path.exists():
            continue
        # Keep only the tensors under test so the rest of the shard can be
        # released as soon as this iteration ends.
        for name, tensor in load_file(str(shard_path)).items():
            if name.startswith(prefixes):
                tensors[name] = tensor
    return tensors


def _compare_outputs(prod, ref):
    """Return (cosine similarity, max|prod|, max|ref|, max|prod| / max|ref|)."""
    cos = F.cosine_similarity(prod.flatten().float(), ref.flatten().float(), dim=0).item()
    prod_max = prod.abs().max().item()
    ref_max = ref.abs().max().item()
    return cos, prod_max, ref_max, prod_max / (ref_max + 1e-10)


def main():
    """Compare Nvfp4Linear against a dequantize-then-F.linear reference.

    For the ``q_a_proj`` and ``kv_proj`` projections of layers 0, 30 and 60,
    loads the packed weight and its scales from ``CHECKPOINT_DIR``, runs a
    random BF16 input and an all-ones input through ``Nvfp4Linear``, and
    prints cosine similarity and magnitude ratio against
    ``F.linear(x, dequant_nvfp4(...))``.

    Requires a CUDA device (``cuda:0``) and the project's ``dsv4`` package.

    Raises:
        FileNotFoundError: if ``CHECKPOINT_DIR`` is not an existing directory.
    """
    device = "cuda:0"
    torch.manual_seed(42)

    cdir = Path(CHECKPOINT_DIR)
    if not cdir.is_dir():
        raise FileNotFoundError(f"Checkpoint directory not found: {cdir}")

    test_layers = [0, 30, 60]
    projs = ['q_a_proj', 'kv_proj']

    # Only these projections are inspected, so only their shards are read.
    prefixes = tuple(
        f"model.layers.{li}.self_attn.{proj}." for li in test_layers for proj in projs
    )
    all_w = _load_tensors_with_prefixes(cdir, prefixes)
    print(f"Loaded {len(all_w)} tensors")

    from dsv4.layers.linear import Nvfp4Linear

    for li in test_layers:
        pfx = f"model.layers.{li}.self_attn"
        for proj in projs:
            weight, ws, ws2, isc = get_nvfp4_weight(all_w, pfx, proj)
            # The block scale is mandatory for both paths; treat a weight
            # without it the same as a missing weight.
            if weight is None or ws is None:
                print(f"L{li} {proj}: not found, skipping")
                continue

            weight = weight.to(device)
            ws = ws.to(device)
            ws2 = ws2.to(device) if ws2 is not None else None
            isc = isc.to(device) if isc is not None else None

            actual_out = weight.shape[0]
            actual_in = weight.shape[1] * 2  # two 4-bit codes per packed byte

            # BF16 input (same as the model would provide).
            x = torch.randn(1, actual_in, dtype=torch.bfloat16, device=device) * 2.0

            # Reference weight, taken before the tensors are handed to the
            # production layer: the view below shares storage with `weight`,
            # so this stays valid even if finalize_weights() were to rewrite
            # its inputs in place (not verified here).
            w_ref = dequant_nvfp4(weight, ws, ws2)

            # === Test A: full production path ===
            lin = Nvfp4Linear(actual_in, actual_out, max_num_tokens=8192, device=device)
            lin.fp4 = [weight.view(torch.float4_e2m1fn_x2) if weight.dtype == torch.uint8 else weight]
            lin.sf = [ws]
            lin.gs = [1.0]
            lin.ws2 = [ws2]
            # Fall back to 1 / (6 * 448) when the checkpoint has no input scale.
            isc_val = isc.float().item() if isc is not None else 1.0 / (6.0 * 448.0)
            lin._activation_global_scale = isc_val
            lin.finalize_weights()
            prod_out = lin(x)

            # === Test B: PyTorch reference, F.linear on the dequantized weight ===
            ref_out = F.linear(x, w_ref)

            cos_full, prod_max, ref_max, ratio = _compare_outputs(prod_out, ref_out)

            # Per the original author's note, the finalized weight is in a
            # K-major, swizzled layout that is not easily undone, so only the
            # layer's scalar `_gsb` is reported alongside the checkpoint scales.
            gsb = lin._gsb.item() if lin._gsb is not None else 1.0
            ws2_val = ws2.float().item() if ws2 is not None else 1.0

            print(f"L{li} {proj}: cos={cos_full:.6f} |prod|={prod_max:.4f} |ref|={ref_max:.4f} ratio={ratio:.4f} gsb={gsb:.6f} ws2={ws2_val:.6f} gsa={isc_val:.8f}")

            # === Test D: all-ones input through the same layer ===
            # A constant input removes input variation from the comparison,
            # so each reference output is simply a row sum of the
            # dequantized weight.
            x_ones = torch.ones(1, actual_in, dtype=torch.bfloat16, device=device)
            prod_ones = lin(x_ones)
            ref_ones = F.linear(x_ones, w_ref)
            cos_ones, prod_ones_max, ref_ones_max, ratio_ones = _compare_outputs(prod_ones, ref_ones)
            print(f" all-ones: cos={cos_ones:.6f} |prod|={prod_ones_max:.4f} |ref|={ref_ones_max:.4f} ratio={ratio_ones:.4f}")

    print("\nDone.")
|
|
|
|
# Run the comparison only when this file is executed as a script, so that
# importing it (e.g. to reuse dequant_nvfp4) has no side effects.
if __name__ == "__main__":
    main()
|