158 lines
6.1 KiB
Python
158 lines
6.1 KiB
Python
#!/usr/bin/env python3
|
|
"""Test fused RMSNorm + NVFP4 quantize kernel.
|
|
|
|
Validates the fused kernel against the reference (unfused) path:
|
|
Reference: rmsnorm(x, weight) → quantize_nvfp4_gpu_fused(rmsnormed)
|
|
Fused: rmsnorm_quantize_nvfp4(x, weight, eps, divisor) → (fp4, sf, gsa, inv_rms)
|
|
|
|
Tests at production shapes: (1, 7168) for decode, (8, 7168) and (128, 7168) for prefill.
|
|
"""
|
|
|
|
import torch
|
|
import math
|
|
|
|
def rmsnorm_ref(x, weight, eps=1e-6):
    """Reference RMSNorm computed in FP32 with plain PyTorch ops.

    Args:
        x: Input tensor; normalized over its last dimension.
        weight: Scale tensor multiplied into the normalized values.
        eps: Constant added to the mean square before the reciprocal
            square root.

    Returns:
        FP32 tensor ``x * rsqrt(mean(x**2, dim=-1) + eps) * weight``.
    """
    x_fp32 = x.float()
    mean_square = torch.square(x_fp32).mean(dim=-1, keepdim=True)
    # Reciprocal of the RMS, kept with a trailing singleton dim for broadcasting.
    inv_rms = torch.rsqrt(mean_square + eps)
    scaled = x_fp32 * inv_rms
    return scaled * weight.float()
|
|
|
|
|
|
def main():
    """Compare the fused RMSNorm + NVFP4 quantize kernel with the unfused path.

    Runs three groups of checks on random BF16 inputs of width 7168:

    1. Weighted RMSNorm for M = 1, 8, 128: output shapes, the global scale
       (gsa) for M == 1, and cosine similarity of the dequantized outputs
       against both the unfused path and the FP32 reference RMSNorm.
    2. Unit-weight RMSNorm for M = 1, 8: cosine similarity fused vs unfused.
    3. Per-row gsa for M = 1, 8 against ``amax / DIVISOR`` of the FP32
       reference RMSNorm.

    Every comparison is written so that a NaN result counts as a failure.

    Returns:
        0 if every check passed, 1 otherwise (used as the process exit code).
    """
    if not torch.cuda.is_available():
        # Fail with a readable message instead of a stack trace from the
        # first tensor allocation on a missing device.
        print("CUDA device not available; cannot run fused kernel tests")
        return 1

    from dsv4.ops.quantize import quantize_nvfp4_gpu_fused, rmsnorm_quantize_nvfp4, dequantize_nvfp4

    device = "cuda"
    torch.manual_seed(42)

    DIVISOR = 6.0 * 448.0  # same as production
    EPS = 1e-6
    HIDDEN = 7168  # production hidden_size

    # gsa is amax / DIVISOR, i.e. on the order of 1e-3 for these inputs, so an
    # absolute tolerance (the previous 1e-3 and 0.1) could never flag a wrong
    # scale. Compare relatively instead.
    # NOTE(review): 5% is a deliberately loose bound; tighten once the
    # precision of the kernel's amax reduction is confirmed.
    GSA_RTOL = 5e-2
    MIN_COS_VS_UNFUSED = 0.995
    MIN_COS_VS_TRUE = 0.990

    all_pass = True

    test_configs = [
        (1, HIDDEN, "decode T=1"),
        (8, HIDDEN, "prefill T=8"),
        (128, HIDDEN, "prefill T=128"),
    ]
    for M, N, label in test_configs:
        print(f"\n=== Test: {label} (M={M}, N={N}) ===")

        x = torch.randn(M, N, dtype=torch.bfloat16, device=device)
        # Keep the weight near 1.0, as norm weights are in practice.
        weight = torch.randn(N, dtype=torch.float32, device=device) * 0.1 + 1.0

        # Reference path: FP32 RMSNorm, rounded to BF16, then unfused quantize.
        x_normed = rmsnorm_ref(x, weight, EPS)
        x_fp4_ref, x_sf_ref, gsa_ref = quantize_nvfp4_gpu_fused(x_normed.bfloat16(), DIVISOR)

        # Fused path.
        x_fp4_fused, x_sf_fused, gsa_fused, inv_rms_fused = rmsnorm_quantize_nvfp4(
            x, weight, EPS, DIVISOR
        )

        # Shape mismatches are reported like every other failure rather than
        # asserted, so they survive ``python -O`` and do not abort the run.
        # NOTE(review): inv_rms_fused is only shape-checked; its values are
        # not compared against a reference anywhere in this script.
        shapes_match = _shapes_ok([
            ("fp4", x_fp4_fused, (M, N // 2)),
            ("sf", x_sf_fused, (M, N // 16)),
            ("gsa", gsa_fused, (M,)),
            ("inv_rms", inv_rms_fused, (M,)),
        ])
        if not shapes_match:
            all_pass = False
            continue

        # The fused gsa is per-row while the unfused path computes one scalar
        # from the whole tensor, so the two are only comparable for M == 1.
        if M == 1:
            fused_gsa = gsa_fused[0].item()
            unfused_gsa = gsa_ref[0].item()
            gsa_diff = abs(fused_gsa - unfused_gsa)
            print(f" gsa: fused={fused_gsa:.6f} ref={unfused_gsa:.6f} diff={gsa_diff:.2e}")
            # ``not (a <= b)`` rather than ``a > b`` so that NaN fails.
            if not gsa_diff <= GSA_RTOL * abs(unfused_gsa):
                print(" FAIL: gsa mismatch")
                all_pass = False

        # Dequantize both paths and compare end-to-end.
        dq_fused = dequantize_nvfp4(x_fp4_fused, x_sf_fused, gsa_fused)
        dq_ref = dequantize_nvfp4(x_fp4_ref, x_sf_ref, gsa_ref)

        cos = _flat_cosine(dq_fused, dq_ref)
        print(f" Dequant cosine (fused vs unfused): {cos:.6f}")

        # Also compare against the true (unquantized) RMSNorm output.
        cos_vs_ref = _flat_cosine(dq_fused, x_normed)
        print(f" vs true RMSNorm: {cos_vs_ref:.6f}")

        # Note: fused uses per-row gsa, unfused uses scalar gsa.
        # Per-row gsa is MORE correct. Small cosine diff (0.996-0.999) is expected
        # because the quantization scaling differs slightly between the two paths.
        # The key metric is cos_vs_ref (vs true RMSNorm) which is ~0.994 for both.
        if not cos >= MIN_COS_VS_UNFUSED:  # NaN must fail
            print(f" FAIL: dequant cosine too low ({cos:.6f})")
            all_pass = False
        if not cos_vs_ref >= MIN_COS_VS_TRUE:  # NaN must fail
            print(f" FAIL: vs true RMSNorm cosine too low ({cos_vs_ref:.6f})")
            all_pass = False

    # Test: unweighted RMSNorm (for q_b norm and kv norm)
    print("\n=== Test: unweighted RMSNorm (weight=1.0) ===")
    for M, N, label in [(1, HIDDEN, "decode"), (8, HIDDEN, "prefill")]:
        x = torch.randn(M, N, dtype=torch.bfloat16, device=device)
        weight_ones = torch.ones(N, dtype=torch.float32, device=device)

        x_normed = rmsnorm_ref(x, weight_ones, EPS).bfloat16()
        x_fp4_ref, x_sf_ref, gsa_ref = quantize_nvfp4_gpu_fused(x_normed, DIVISOR)

        x_fp4_fused, x_sf_fused, gsa_fused, inv_rms_fused = rmsnorm_quantize_nvfp4(
            x, weight_ones, EPS, DIVISOR
        )

        dq_fused = dequantize_nvfp4(x_fp4_fused, x_sf_fused, gsa_fused)
        dq_ref = dequantize_nvfp4(x_fp4_ref, x_sf_ref, gsa_ref)

        cos = _flat_cosine(dq_fused, dq_ref)
        print(f" {label}: cos={cos:.6f}")
        if not cos >= MIN_COS_VS_UNFUSED:  # NaN must fail
            print(" FAIL")
            all_pass = False

    # Test: validate gsa against per-row reference
    print("\n=== Test: gsa per-row validation ===")
    for M, N, label in [(1, HIDDEN, "decode"), (8, HIDDEN, "prefill")]:
        x = torch.randn(M, N, dtype=torch.bfloat16, device=device)
        weight = torch.randn(N, dtype=torch.float32, device=device) * 0.1 + 1.0

        x_fp4_fused, x_sf_fused, gsa_fused, inv_rms_fused = rmsnorm_quantize_nvfp4(
            x, weight, EPS, DIVISOR
        )

        # Expected per-row scale: amax of the FP32 reference RMSNorm over the
        # row, floored at 1e-8, divided by DIVISOR. Computed for all rows at
        # once instead of one ``.item()`` round-trip per row.
        ref_gsa = rmsnorm_ref(x, weight, EPS).abs().amax(dim=-1).clamp(min=1e-8) / DIVISOR
        abs_diff = (gsa_fused.float().reshape(-1) - ref_gsa).abs()
        # torch's max propagates NaN (unlike the builtin ``max(a, nan)``), so
        # a NaN scale surfaces in both numbers below.
        max_diff = abs_diff.max().item()
        max_rel_diff = (abs_diff / ref_gsa).max().item()

        print(f" {label}: max gsa diff vs reference = {max_diff:.2e}")
        print(f" {label}: max relative gsa diff vs reference = {max_rel_diff:.2e}")
        if not max_rel_diff <= GSA_RTOL:  # NaN must fail
            print(" FAIL: gsa diff too large")
            all_pass = False

    print(f"\n{'='*60}")
    print(f"{'ALL TESTS PASSED' if all_pass else 'SOME TESTS FAILED'}")
    return 0 if all_pass else 1


def _flat_cosine(a, b):
    """Return the cosine similarity of ``a`` and ``b``, each flattened to one FP32 vector."""
    return torch.nn.functional.cosine_similarity(
        a.float().flatten().unsqueeze(0),
        b.float().flatten().unsqueeze(0),
    ).item()


def _shapes_ok(checks):
    """Print a FAIL line for each ``(name, tensor, expected_shape)`` mismatch; True if none."""
    ok = True
    for name, tensor, expected in checks:
        actual = tuple(tensor.shape)
        if actual != expected:
            print(f" FAIL: {name} shape: {actual} != {expected}")
            ok = False
    return ok
|
|
|
|
|
|
if __name__ == "__main__":
    # The ``exit`` builtin is installed by the ``site`` module for interactive
    # sessions and is absent under ``python -S``; raising SystemExit directly
    # always works and propagates main()'s return value as the exit status.
    raise SystemExit(main())
|