Files
nvfp4-megamoe-kernel/tests/unit/test_fused_mhc_rmsnorm_quantize.py
biondizzle 454dbdad52 P5: Fused mHC pre_block + RMSNorm + NVFP4 quantize kernel
- fused_mhc_rmsnorm_quantize.cu: 2-kernel approach
  Kernel 1: mhc_rmsnorm_amax_gsa — bmm + RMS + amax → gsa
  Kernel 2: mhc_rmsnorm_quantize_nvfp4 — bmm + normalize + quantize
- Python bridge: mhc_rmsnorm_quantize_nvfp4() in ops/quantize.py
- Unit test: test_fused_mhc_rmsnorm_quantize.py (production shapes)
- Eliminates ~610 kernel launches per token (122 sites × 5 launches saved)
2026-06-02 16:39:42 +00:00

106 lines
3.8 KiB
Python

#!/usr/bin/env python3
"""Test fused mHC pre_block + RMSNorm + NVFP4 quantize kernel.
Validates the fused kernel against the reference (unfused) path:
Reference: mHC.pre_block(X_l) → rmsnorm(x_in, weight) → quantize_nvfp4_gpu_fused(x_normed)
Fused: mhc_rmsnorm_quantize_nvfp4(X_l, A_l, weight) → QuantizedActivation
Tests at production shapes: (1, 4, 7168) for decode, and (8, 4, 7168) and (128, 4, 7168) for prefill.
"""
import torch
import math
def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """PyTorch reference RMSNorm over the last dimension.

    The input is upcast to FP32 before any arithmetic, so the result is
    always FP32 regardless of the dtype of ``x`` or ``weight``.

    Args:
        x: Tensor to normalize; the reduction runs over ``dim=-1``.
        weight: Per-element scale, broadcast against ``x`` along the last
            dimension.
        eps: Added to the mean square before the reciprocal square root.

    Returns:
        ``x * rsqrt(mean(x**2, dim=-1) + eps) * weight`` as an FP32 tensor.
    """
    xf = x.float()
    # Reciprocal RMS per row; keepdim=True so it broadcasts back over xf.
    inv_rms = xf.pow(2).mean(dim=-1, keepdim=True).add(eps).rsqrt()
    return xf * inv_rms * weight.float()
def _flat_cosine(a, b):
    """Return the cosine similarity of two tensors, each flattened to one FP32 vector."""
    return torch.nn.functional.cosine_similarity(
        a.float().flatten().unsqueeze(0),
        b.float().flatten().unsqueeze(0),
    ).item()


def main():
    """Compare the fused mHC + RMSNorm + NVFP4 quantize op with the unfused path.

    For each (M, n_hc, N) configuration this builds random ``X_l`` / ``A_l`` /
    norm-weight inputs on the CUDA device, runs both paths, and checks:

    * output shapes of the fused result,
    * cosine similarity of the dequantized fused output against the
      dequantized unfused output and against the FP32 PyTorch reference,
    * the per-row ``gsa`` against ``max(amax, 1e-8) / DIVISOR`` computed from
      the PyTorch reference.

    Returns:
        0 if every configuration passed, 1 otherwise (suitable as a process
        exit code).

    Raises:
        AssertionError: if a fused output tensor has an unexpected shape.
    """
    # Imported lazily so that importing this module does not require dsv4.
    from dsv4.ops.quantize import (
        dequantize_nvfp4,
        mhc_rmsnorm_quantize_nvfp4,
        quantize_nvfp4_gpu_fused,
    )

    device = "cuda"
    torch.manual_seed(42)

    # NOTE(review): presumably FP4 (E2M1) max 6.0 x FP8 (E4M3) max 448.0 — confirm
    # against the kernel's definition of the global scale.
    DIVISOR = 6.0 * 448.0
    EPS = 1e-6
    N_HC = 4
    HIDDEN = 7168

    MIN_COS_VS_UNFUSED = 0.995
    MIN_COS_VS_TRUE = 0.990
    # gsa is amax / DIVISOR, i.e. a small number; an absolute tolerance cannot
    # detect a wrong value, so the check is relative. The reference rounds
    # x_in to bf16 before normalizing, which accounts for well under 1% of
    # deviation; 5% leaves headroom while still catching a wrong amax.
    GSA_RTOL = 5e-2

    test_configs = [
        (1, N_HC, HIDDEN, "decode T=1"),
        (8, N_HC, HIDDEN, "prefill T=8"),
        (128, N_HC, HIDDEN, "prefill T=128"),
    ]

    all_pass = True
    for M, n_hc, N, label in test_configs:
        print(f"\n=== Test: {label} (M={M}, n_hc={n_hc}, N={N}) ===")

        # Create X_l and A_l (mHC pre_block inputs)
        X_l = torch.randn(M, n_hc, N, dtype=torch.bfloat16, device=device)
        A_l = torch.softmax(
            torch.randn(M, n_hc, dtype=torch.bfloat16, device=device), dim=-1
        )
        norm_weight = torch.randn(N, dtype=torch.float32, device=device) * 0.1 + 1.0

        # Reference path: unfused
        x_in = torch.bmm(A_l.unsqueeze(1).float(), X_l.float()).squeeze(1)  # (M, N) FP32
        x_normed = rmsnorm_ref(x_in.bfloat16(), norm_weight, EPS)
        x_fp4_ref, x_sf_ref, gsa_ref = quantize_nvfp4_gpu_fused(
            x_normed.bfloat16(), DIVISOR
        )

        # Fused path
        quant = mhc_rmsnorm_quantize_nvfp4(X_l, A_l, norm_weight, EPS, DIVISOR)

        # Validate shapes
        assert quant.x_fp4.shape == (M, N // 2), f"fp4 shape: {quant.x_fp4.shape}"
        assert quant.x_sf.shape == (M, N // 16), f"sf shape: {quant.x_sf.shape}"
        assert quant.gsa.shape == (M,), f"gsa shape: {quant.gsa.shape}"

        # Dequantize and compare
        dq_fused = dequantize_nvfp4(quant.x_fp4, quant.x_sf, quant.gsa)
        dq_ref = dequantize_nvfp4(x_fp4_ref, x_sf_ref, gsa_ref)

        # Cosine similarity
        cos = _flat_cosine(dq_fused, dq_ref)
        print(f" Dequant cosine (fused vs unfused): {cos:.6f}")

        # Against true mHC + RMSNorm output
        cos_vs_ref = _flat_cosine(dq_fused, x_normed)
        print(f" vs true mHC+RMSNorm: {cos_vs_ref:.6f}")

        # Written as `not (value >= bound)` so that a NaN result (every
        # comparison with NaN is False) is reported as a failure instead of
        # slipping through a `value < bound` test.
        if not (cos >= MIN_COS_VS_UNFUSED):
            print(f" FAIL: dequant cosine too low ({cos:.6f})")
            all_pass = False
        if not (cos_vs_ref >= MIN_COS_VS_TRUE):
            print(f" FAIL: vs true mHC+RMSNorm cosine too low ({cos_vs_ref:.6f})")
            all_pass = False

        # gsa per-row validation, vectorized over every row. x_normed already
        # holds the reference mHC + RMSNorm output for all M rows.
        ref_gsa = x_normed.abs().amax(dim=-1).clamp_min(1e-8) / DIVISOR  # (M,)
        gsa_abs_diff = (quant.gsa.float() - ref_gsa).abs()
        # Tensor.max() propagates NaN, unlike Python's built-in max().
        max_gsa_diff = gsa_abs_diff.max().item()
        max_gsa_rel = (gsa_abs_diff / ref_gsa).max().item()
        print(f" gsa per-row diff: {max_gsa_diff:.2e} (max rel {max_gsa_rel:.2e})")
        if not (max_gsa_rel <= GSA_RTOL):
            print(f" FAIL: gsa relative diff too large ({max_gsa_rel:.2e} > {GSA_RTOL:.0e})")
            all_pass = False

    print(f"\n{'='*60}")
    print(f"{'ALL TESTS PASSED' if all_pass else 'SOME TESTS FAILED'}")
    return 0 if all_pass else 1
if __name__ == "__main__":
    # Propagate main()'s 0/1 result as the process exit status. SystemExit is
    # raised directly because the bare `exit` name is an interactive helper
    # installed by the `site` module and is absent under `python -S`.
    raise SystemExit(main())