Files
nvfp4-megamoe-kernel/tests/unit/test_fused_rmsnorm_quantize.py

158 lines
6.1 KiB
Python

#!/usr/bin/env python3
"""Test fused RMSNorm + NVFP4 quantize kernel.
Validates the fused kernel against the reference (unfused) path:
Reference: rmsnorm(x, weight) → quantize_nvfp4_gpu_fused(rmsnormed)
Fused: rmsnorm_quantize_nvfp4(x, weight, eps, divisor) → (fp4, sf, gsa, inv_rms)
Tests at production shapes: (1, 7168) for decode; (8, 7168) and (128, 7168) for prefill.
"""
import torch
import math
def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """PyTorch reference RMSNorm.

    The computation is done in float32 regardless of the input dtypes:
    ``x`` is scaled by ``rsqrt(mean(x**2, dim=-1) + eps)`` and then multiplied
    elementwise by ``weight``.

    Args:
        x: Input tensor; the normalization statistic is taken over its last
            dimension.
        weight: Scale tensor, multiplied elementwise with the normalized input.
        eps: Value added to the mean square before the reciprocal square root.

    Returns:
        The normalized, weighted tensor in float32.
    """
    xf = x.float()
    # Reciprocal RMS per row; keepdim=True keeps it broadcastable against xf.
    inv_rms = xf.pow(2).mean(dim=-1, keepdim=True).add(eps).rsqrt()
    return xf * inv_rms * weight.float()
def _flat_cosine(a, b):
    """Return the cosine similarity of two tensors, each flattened to 1-D in float32."""
    return torch.nn.functional.cosine_similarity(
        a.float().flatten().unsqueeze(0),
        b.float().flatten().unsqueeze(0),
    ).item()


def main():
    """Run the fused RMSNorm + NVFP4 quantize checks against the unfused path.

    Three groups of checks are run:
      1. weighted RMSNorm at several (M, N) shapes: output shapes, gsa at M == 1,
         and cosine similarity of the dequantized outputs;
      2. unweighted RMSNorm (weight == 1.0): cosine similarity fused vs unfused;
      3. per-row gsa against ``amax(rmsnorm_ref(row)) / DIVISOR``.

    All threshold comparisons are written as ``not (value >= thr)`` /
    ``not (value <= thr)`` so that a NaN result counts as a failure instead of
    slipping through a ``<`` / ``>`` comparison.

    Returns:
        0 if every check passed, 1 otherwise.
    """
    # Imported inside main(), so merely importing this module does not need dsv4.
    from dsv4.ops.quantize import (
        dequantize_nvfp4,
        quantize_nvfp4_gpu_fused,
        rmsnorm_quantize_nvfp4,
    )

    device = "cuda"
    torch.manual_seed(42)
    DIVISOR = 6.0 * 448.0  # same as production
    EPS = 1e-6
    HIDDEN = 7168  # production hidden_size
    # gsa values are of order amax / DIVISOR (~1e-3), so an absolute tolerance
    # cannot discriminate anything; compare relative to the reference instead.
    # NOTE(review): 5% is a deliberately generous bound for BF16 rounding --
    # tighten once the kernel's internal precision is confirmed.
    GSA_REL_TOL = 5e-2

    test_configs = [
        (1, HIDDEN, "decode T=1"),
        (8, HIDDEN, "prefill T=8"),
        (128, HIDDEN, "prefill T=128"),
    ]
    all_pass = True

    for M, N, label in test_configs:
        print(f"\n=== Test: {label} (M={M}, N={N}) ===")
        x = torch.randn(M, N, dtype=torch.bfloat16, device=device)
        # Norm weights are near 1.0 in practice, so draw them around 1.0.
        weight = torch.randn(N, dtype=torch.float32, device=device) * 0.1 + 1.0

        # Reference path: unfused RMSNorm (FP32), rounded to BF16, then quantized.
        x_normed = rmsnorm_ref(x, weight, EPS)  # (M, N) FP32
        x_normed_bf16 = x_normed.bfloat16()
        x_fp4_ref, x_sf_ref, gsa_ref = quantize_nvfp4_gpu_fused(x_normed_bf16, DIVISOR)

        # Fused path
        x_fp4_fused, x_sf_fused, gsa_fused, inv_rms_fused = rmsnorm_quantize_nvfp4(
            x, weight, EPS, DIVISOR
        )

        # Validate shapes
        assert x_fp4_fused.shape == (M, N // 2), f"fp4 shape: {x_fp4_fused.shape} != {(M, N//2)}"
        assert x_sf_fused.shape == (M, N // 16), f"sf shape: {x_sf_fused.shape} != {(M, N//16)}"
        assert gsa_fused.shape == (M,), f"gsa shape: {gsa_fused.shape} != {(M,)}"
        assert inv_rms_fused.shape == (M,), f"inv_rms shape: {inv_rms_fused.shape} != {(M,)}"

        # Validate gsa. The fused path is per-row while the unfused path derives
        # one value from the whole tensor, so they are only comparable at M == 1.
        if M == 1:
            gsa_fused_val = gsa_fused[0].item()
            # reshape(-1) accepts both a 0-dim scalar and a 1-element tensor.
            gsa_ref_val = gsa_ref.reshape(-1)[0].item()
            gsa_diff = abs(gsa_fused_val - gsa_ref_val)
            print(f" gsa: fused={gsa_fused_val:.6f} ref={gsa_ref_val:.6f} diff={gsa_diff:.2e}")
            if not (gsa_diff <= GSA_REL_TOL * abs(gsa_ref_val)):
                print(" FAIL: gsa mismatch")
                all_pass = False

        # Dequantize and compare end-to-end
        dq_fused = dequantize_nvfp4(x_fp4_fused, x_sf_fused, gsa_fused)
        dq_ref = dequantize_nvfp4(x_fp4_ref, x_sf_ref, gsa_ref)

        cos = _flat_cosine(dq_fused, dq_ref)
        print(f" Dequant cosine (fused vs unfused): {cos:.6f}")

        # Also compare against the true (unquantized) RMSNorm output
        cos_vs_ref = _flat_cosine(dq_fused, x_normed)
        print(f" vs true RMSNorm: {cos_vs_ref:.6f}")

        if not (cos >= 0.995):
            print(f" FAIL: dequant cosine too low ({cos:.6f})")
            all_pass = False
        if not (cos_vs_ref >= 0.990):
            print(f" FAIL: vs true RMSNorm cosine too low ({cos_vs_ref:.6f})")
            all_pass = False
        # Note: fused uses per-row gsa, unfused uses scalar gsa.
        # Per-row gsa is MORE correct. Small cosine diff (0.996-0.999) is expected
        # because the quantization scaling differs slightly between the two paths.
        # The key metric is cos_vs_ref (vs true RMSNorm) which is ~0.994 for both.

    # Test: unweighted RMSNorm (for q_b norm and kv norm)
    print("\n=== Test: unweighted RMSNorm (weight=1.0) ===")
    for M, N, label in [(1, HIDDEN, "decode"), (8, HIDDEN, "prefill")]:
        x = torch.randn(M, N, dtype=torch.bfloat16, device=device)
        weight_ones = torch.ones(N, dtype=torch.float32, device=device)
        x_normed = rmsnorm_ref(x, weight_ones, EPS).bfloat16()
        x_fp4_ref, x_sf_ref, gsa_ref = quantize_nvfp4_gpu_fused(x_normed, DIVISOR)
        x_fp4_fused, x_sf_fused, gsa_fused, inv_rms_fused = rmsnorm_quantize_nvfp4(
            x, weight_ones, EPS, DIVISOR
        )
        dq_fused = dequantize_nvfp4(x_fp4_fused, x_sf_fused, gsa_fused)
        dq_ref = dequantize_nvfp4(x_fp4_ref, x_sf_ref, gsa_ref)
        cos = _flat_cosine(dq_fused, dq_ref)
        print(f" {label}: cos={cos:.6f}")
        if not (cos >= 0.995):
            print(" FAIL")
            all_pass = False

    # Test: validate gsa against a per-row reference
    print("\n=== Test: gsa per-row validation ===")
    for M, N, label in [(1, HIDDEN, "decode"), (8, HIDDEN, "prefill")]:
        x = torch.randn(M, N, dtype=torch.bfloat16, device=device)
        weight = torch.randn(N, dtype=torch.float32, device=device) * 0.1 + 1.0
        x_fp4_fused, x_sf_fused, gsa_fused, inv_rms_fused = rmsnorm_quantize_nvfp4(
            x, weight, EPS, DIVISOR
        )
        # rmsnorm_ref reduces over the last dim only, so one batched call gives
        # the same per-row amax as normalizing each row separately.
        ref_amax = rmsnorm_ref(x, weight, EPS).abs().amax(dim=-1)
        ref_gsa = ref_amax.clamp(min=1e-8) / DIVISOR
        abs_diff = (gsa_fused.float().reshape(-1) - ref_gsa).abs()
        max_diff = abs_diff.max().item()
        max_rel_diff = (abs_diff / ref_gsa).max().item()
        print(f" {label}: max gsa diff vs reference = {max_diff:.2e}")
        print(f" {label}: max relative gsa diff = {max_rel_diff:.2e} (tol {GSA_REL_TOL:.0e})")
        if not (max_rel_diff <= GSA_REL_TOL):
            print(" FAIL: gsa diff too large")
            all_pass = False

    print(f"\n{'='*60}")
    print(f"{'ALL TESTS PASSED' if all_pass else 'SOME TESTS FAILED'}")
    return 0 if all_pass else 1
if __name__ == "__main__":
    # Raise SystemExit directly: the exit() builtin is injected by the `site`
    # module for interactive use and is not available under `python -S`.
    raise SystemExit(main())