- fused_mhc_rmsnorm_quantize.cu: two-kernel approach
  - Kernel 1: mhc_rmsnorm_amax_gsa — bmm + RMS + amax → gsa
  - Kernel 2: mhc_rmsnorm_quantize_nvfp4 — bmm + normalize + quantize
- Python bridge: mhc_rmsnorm_quantize_nvfp4() in ops/quantize.py
- Unit test: test_fused_mhc_rmsnorm_quantize.py (production shapes)
- Eliminates ~610 kernel launches per token (122 sites × 5 launches saved)
106 lines · 3.8 KiB · Python
#!/usr/bin/env python3
"""Test the fused mHC pre_block + RMSNorm + NVFP4 quantize kernel.

Validates the fused kernel against the reference (unfused) path:

    Reference: mHC.pre_block(X_l) → rmsnorm(x_in, weight) → quantize_nvfp4_gpu_fused(x_normed)
    Fused:     mhc_rmsnorm_quantize_nvfp4(X_l, A_l, weight) → QuantizedActivation

Tests at production shapes (tokens, n_hc, hidden): (1, 4, 7168) for decode,
and (8, 4, 7168) and (128, 4, 7168) for prefill.
"""
|
|
|
|
import torch
|
|
import math
|
|
|
|
def rmsnorm_ref(x, weight, eps=1e-6):
    """Reference RMSNorm written in plain PyTorch and evaluated in FP32.

    Every vector along the last dimension is multiplied by the reciprocal of
    its root-mean-square (``eps`` is added to the mean of squares first) and
    then scaled elementwise by ``weight``.

    Args:
        x: Input tensor of any floating dtype; upcast to FP32 before use.
        weight: Per-feature scale, broadcast against the last dim of ``x``.
        eps: Stabilizer added to the mean of squares before the rsqrt.

    Returns:
        FP32 tensor (shape of ``x`` broadcast with ``weight``).
    """
    x32 = x.float()
    mean_sq = x32.pow(2).mean(dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_sq + eps)
    # Same multiplication order as x * inv_rms * weight to keep results bit-identical.
    return x32 * inv_rms * weight.float()
|
|
|
|
|
|
def _flat_cosine(a, b):
    """Return the cosine similarity of ``a`` and ``b`` as a Python float.

    Both tensors are upcast to FP32 and flattened into one vector each, so the
    score is computed over all elements at once (not averaged per row).
    """
    return torch.nn.functional.cosine_similarity(
        a.float().flatten().unsqueeze(0),
        b.float().flatten().unsqueeze(0),
    ).item()


def main():
    """Compare the fused mHC + RMSNorm + NVFP4 kernel against the unfused path.

    For each configuration, random ``X_l`` (M, n_hc, N) and softmax-normalized
    ``A_l`` (M, n_hc) are created on CUDA. The reference path mixes the n_hc
    streams with ``torch.bmm``, applies ``rmsnorm_ref`` and quantizes with
    ``quantize_nvfp4_gpu_fused``; the fused path calls
    ``mhc_rmsnorm_quantize_nvfp4`` directly. Checks:

    * shapes of ``x_fp4`` (M, N // 2), ``x_sf`` (M, N // 16), ``gsa`` (M,);
    * cosine of the two dequantized outputs >= 0.995;
    * cosine of the fused dequantized output vs the unquantized reference
      RMSNorm output >= 0.990;
    * every row's ``gsa`` vs ``max(amax(row), 1e-8) / DIVISOR`` within a
      relative tolerance ``GSA_RTOL``.

    Returns:
        0 if every configuration passes, 1 otherwise (used as exit status).
    """
    from dsv4.ops.quantize import (
        dequantize_nvfp4,
        mhc_rmsnorm_quantize_nvfp4,
        quantize_nvfp4_gpu_fused,
    )

    device = "cuda"
    torch.manual_seed(42)

    # NOTE(review): presumably FP4 E2M1 max (6.0) x FP8 E4M3 max (448.0), the
    # NVFP4 global-scale divisor — confirm against the kernel.
    DIVISOR = 6.0 * 448.0
    EPS = 1e-6
    N_HC = 4
    HIDDEN = 7168
    # gsa = amax / DIVISOR is only ~1e-3 for these unit-RMS random rows, so an
    # absolute bound (the previous 0.5) could never fail. Compare relatively;
    # the reference's bf16 rounding of x_in (step ~2^-8) should stay well
    # inside this. Tune if the kernel's accumulation precision differs.
    GSA_RTOL = 2e-2

    test_configs = [
        (1, N_HC, HIDDEN, "decode T=1"),
        (8, N_HC, HIDDEN, "prefill T=8"),
        (128, N_HC, HIDDEN, "prefill T=128"),
    ]

    all_pass = True
    for M, n_hc, N, label in test_configs:
        print(f"\n=== Test: {label} (M={M}, n_hc={n_hc}, N={N}) ===")

        # Create X_l and A_l (mHC pre_block inputs); A_l is softmaxed over n_hc.
        X_l = torch.randn(M, n_hc, N, dtype=torch.bfloat16, device=device)
        A_l = torch.softmax(torch.randn(M, n_hc, dtype=torch.bfloat16, device=device), dim=-1)
        norm_weight = torch.randn(N, dtype=torch.float32, device=device) * 0.1 + 1.0

        # Reference path (unfused): (M, 1, n_hc) @ (M, n_hc, N) -> (M, N) in FP32.
        x_in = torch.bmm(A_l.unsqueeze(1).float(), X_l.float()).squeeze(1)
        # x_in is rounded to bf16 before RMSNorm; x_normed itself is FP32.
        x_normed = rmsnorm_ref(x_in.bfloat16(), norm_weight, EPS)
        x_fp4_ref, x_sf_ref, gsa_ref = quantize_nvfp4_gpu_fused(x_normed.bfloat16(), DIVISOR)

        # Fused path
        quant = mhc_rmsnorm_quantize_nvfp4(X_l, A_l, norm_weight, EPS, DIVISOR)

        # Validate shapes explicitly (not with ``assert``) so the checks survive
        # ``python -O`` and a mismatch is reported instead of aborting the run.
        shape_errors = [
            f"{name} shape: {tuple(t.shape)}, expected {want}"
            for name, t, want in (
                ("fp4", quant.x_fp4, (M, N // 2)),
                ("sf", quant.x_sf, (M, N // 16)),
                ("gsa", quant.gsa, (M,)),
            )
            if tuple(t.shape) != want
        ]
        if shape_errors:
            for err in shape_errors:
                print(f" FAIL: {err}")
            all_pass = False
            continue  # the remaining checks rely on these shapes

        # Dequantize and compare
        dq_fused = dequantize_nvfp4(quant.x_fp4, quant.x_sf, quant.gsa)
        dq_ref = dequantize_nvfp4(x_fp4_ref, x_sf_ref, gsa_ref)

        cos = _flat_cosine(dq_fused, dq_ref)
        print(f" Dequant cosine (fused vs unfused): {cos:.6f}")

        # Against true (unquantized) mHC + RMSNorm output
        cos_vs_ref = _flat_cosine(dq_fused, x_normed)
        print(f" vs true mHC+RMSNorm: {cos_vs_ref:.6f}")

        if cos < 0.995:
            print(f" FAIL: dequant cosine too low ({cos:.6f})")
            all_pass = False
        if cos_vs_ref < 0.990:
            print(f" FAIL: vs true mHC+RMSNorm cosine too low ({cos_vs_ref:.6f})")
            all_pass = False

        # gsa per-row validation over every row, reusing x_normed instead of
        # recomputing each row's bmm + RMSNorm.
        ref_gsa = x_normed.abs().amax(dim=-1).clamp_min(1e-8) / DIVISOR  # (M,)
        fused_gsa = quant.gsa.float().to(ref_gsa.device)
        gsa_rel = ((fused_gsa - ref_gsa).abs() / ref_gsa).max().item()
        print(f" gsa per-row max rel diff: {gsa_rel:.2e}")
        if gsa_rel > GSA_RTOL:
            print(f" FAIL: gsa rel diff too large ({gsa_rel:.2e} > {GSA_RTOL:.0e})")
            all_pass = False

    print(f"\n{'=' * 60}")
    print("ALL TESTS PASSED" if all_pass else "SOME TESTS FAILED")
    return 0 if all_pass else 1
|
|
|
|
|
|
if __name__ == "__main__":
    # Raise SystemExit directly: the ``exit`` builtin is injected by the
    # ``site`` module for interactive use and is not guaranteed to exist
    # (e.g. under ``python -S``).
    raise SystemExit(main())
|