81 lines
3.1 KiB
Python
81 lines
3.1 KiB
Python
"""Test: run random data at small dimensions to check whether non-uniform scale factors (SF) break the NVFP4 block-scaled GEMM."""
|
|
import torch
|
|
import sys
|
|
sys.path.insert(0, 'src')
|
|
|
|
from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import cutlass_nvfp4_blockscaled_gemm
|
|
from nvfp4_megamoe_kernel.nvfp4_mega_moe import _quantize_to_e2m1, _E2M1_MAGNITUDES
|
|
|
|
# Device on which this script allocates its random input tensors.
device = "cuda"

# Fixed seed so repeated runs draw the same random inputs.
torch.manual_seed(42)
|
|
|
|
# Number of consecutive elements that share one scale factor.
_SF_BLOCK = 16


def _decode_nibbles(nib, magnitudes=None):
    """Turn 4-bit codes into signed float32 values.

    Bit 3 of each code is treated as the sign and bits 0-2 index the
    magnitude table.

    Args:
        nib: integer tensor of 4-bit codes (values 0..15).
        magnitudes: optional 8-entry lookup table; defaults to the imported
            ``_E2M1_MAGNITUDES``.

    Returns:
        float32 tensor with the same shape as ``nib``.
    """
    table = _E2M1_MAGNITUDES if magnitudes is None else magnitudes
    signs = (nib >> 3).float() * -2 + 1
    # Follow the input's device instead of a module global so the helper
    # works wherever the packed data lives.
    return signs * table.to(nib.device)[nib & 0x07]


def dequant_a(fp4, sf, M, K, magnitudes=None):
    """Dequantize a packed operand whose nibbles are packed along the last dim.

    Each byte is split into a low and a high nibble which become two
    adjacent elements of the same row (low nibble first).
    NOTE(review): assumes ``_quantize_to_e2m1`` packs the lower-index element
    into the low nibble -- confirm against the quantizer.

    Args:
        fp4: packed tensor with ``M`` rows; reinterpreted as uint8.
        sf: per-block scale factors, one per 16 consecutive row elements.
        M: number of rows of the unpacked result.
        K: unused; kept so existing call sites keep working.
        magnitudes: optional 8-entry magnitude table (see ``_decode_nibbles``).

    Returns:
        bfloat16 tensor with ``M`` rows holding the dequantized values.
    """
    u8 = fp4.view(torch.uint8)
    lo = (u8 & 0x0F).long()
    hi = ((u8 >> 4) & 0x0F).long()
    # Interleave lo/hi along the last axis: element 2j is lo[j], 2j+1 is hi[j].
    nib = torch.stack([lo, hi], dim=-1).reshape(M, -1)
    scales = sf.to(torch.float32).repeat_interleave(_SF_BLOCK, dim=-1)
    return (_decode_nibbles(nib, magnitudes) * scales).to(torch.bfloat16)


def dequant_b(fp4, sf, K, N, magnitudes=None):
    """Dequantize a packed operand whose nibbles are packed along the first dim.

    This is the layout produced by quantizing the transposed matrix and then
    transposing the packed result back: byte ``[j, n]`` holds the elements
    ``[2j, n]`` (low nibble) and ``[2j + 1, n]`` (high nibble).

    Args:
        fp4: packed tensor of shape ``(rows, cols)``; reinterpreted as uint8.
        sf: per-block scale factors, one per 16 consecutive elements of a
            column.
        K: unused; kept so existing call sites keep working.
        N: unused; kept so existing call sites keep working.
        magnitudes: optional 8-entry magnitude table (see ``_decode_nibbles``).

    Returns:
        bfloat16 tensor of shape ``(2 * rows, cols)``.
    """
    u8 = fp4.view(torch.uint8)
    lo = (u8 & 0x0F).long()
    hi = ((u8 >> 4) & 0x0F).long()
    # Stack on dim=1 so the reshape interleaves lo/hi along the ROW axis.
    # Stacking on the last axis (as dequant_a does) would scramble elements
    # across columns; uniform inputs hide that, random inputs do not.
    nib = torch.stack([lo, hi], dim=1).reshape(u8.shape[0] * 2, u8.shape[1])
    scales = sf.to(torch.float32).repeat_interleave(_SF_BLOCK, dim=0)
    return (_decode_nibbles(nib, magnitudes) * scales).to(torch.bfloat16)


def test(M, N, K, label):
    """Compare the CUTLASS NVFP4 GEMM with a dequantized BF16 reference.

    Draws a random (M, K) activation and (K, N) weight, quantizes both with
    ``_quantize_to_e2m1``, runs the kernel on the quantized operands and
    prints the mean row-wise cosine similarity and the MSE against a matmul
    of the dequantized operands.

    Args:
        M: rows of the activation / output.
        N: columns of the weight / output.
        K: shared inner dimension.
        label: tag printed in front of the result line.
    """
    x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device) * 2.0
    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device) * 0.5

    x_fp4, x_sf = _quantize_to_e2m1(x_bf16.float())
    # The weight is quantized in its transposed (N, K) orientation and the
    # packed data and scale factors are then transposed back.
    w_fp4, w_sf = _quantize_to_e2m1(w_bf16.T.float())
    w_fp4 = w_fp4.T
    w_sf = w_sf.T

    # Dequant reference
    x_recon = dequant_a(x_fp4, x_sf, M, K)
    w_recon = dequant_b(w_fp4, w_sf, K, N)
    quant_ref = torch.nn.functional.linear(x_recon, w_recon.T)

    nvfp4_out = cutlass_nvfp4_blockscaled_gemm(x_fp4, x_sf, w_fp4, w_sf, M, N, K, alpha=1.0)

    cos = torch.nn.functional.cosine_similarity(nvfp4_out.float(), quant_ref.float(), dim=-1).mean().item()
    mse = (nvfp4_out.float() - quant_ref.float()).pow(2).mean().item()
    print(f"{label}: M={M} N={N} K={K} cosine={cos:.6f} mse={mse:.4e}")


# `test` takes positional arguments and is driven by the script below; tell
# pytest-style collectors not to treat it as a test case.
test.__test__ = False
|
|
|
|
# All at N=32, K=32 (same as the working all-ones test)
for rows, tag in ((1, "RAND-TINY"), (4, "RAND-M4"), (128, "RAND-M128")):
    test(rows, 32, 32, tag)

# Bigger
for case in (
    (1, 128, 256, "RAND-128x256"),
    (1, 256, 512, "RAND-256x512"),
    (128, 256, 512, "RAND-128x256x512"),
):
    test(*case)
|
|
|
|
# Test with alpha != 1.0
print("\n--- alpha test ---")

# One fixed tiny problem is quantized once and reused for every alpha.
M, N, K = 1, 32, 32
x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device) * 2.0
w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device) * 0.5
x_fp4, x_sf = _quantize_to_e2m1(x_bf16.float())
# Weight is quantized transposed, then packed data and scales are flipped back.
w_fp4, w_sf = _quantize_to_e2m1(w_bf16.T.float())
w_fp4 = w_fp4.T
w_sf = w_sf.T

# Reference product of the dequantized operands, before any alpha scaling.
x_recon = dequant_a(x_fp4, x_sf, M, K)
w_recon = dequant_b(w_fp4, w_sf, K, N)
quant_ref = torch.nn.functional.linear(x_recon, w_recon.T)

for alpha in (1.0, 0.5, 2.0, 1e-3, 4.6e-5):
    nvfp4_out = cutlass_nvfp4_blockscaled_gemm(
        x_fp4, x_sf, w_fp4, w_sf, M, N, K, alpha=alpha
    )
    ref_scaled = quant_ref * alpha
    # M == 1, so the row-wise similarity has a single entry for .item().
    similarity = torch.nn.functional.cosine_similarity(
        nvfp4_out.float(), ref_scaled.float(), dim=-1
    )
    cos = similarity.item()
    print(f" alpha={alpha:.1e} cosine={cos:.6f}")
|