Files
nvfp4-megamoe-kernel/test_m1_gemm.py

102 lines
4.0 KiB
Python

"""Standalone test matching real MoE dimensions: M=1, N=6144, K=7168.
The random test with M=128 showed cosine 1.0, but real inference with M=1
shows cosine ≈ 0. This test uses deterministic data at M=1 to reproduce.
"""
import torch
import sys
sys.path.insert(0, 'src')
from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import (
cutlass_nvfp4_blockscaled_gemm,
)
from nvfp4_megamoe_kernel.nvfp4_mega_moe import _quantize_to_e2m1, _E2M1_MAGNITUDES
def dequantize_e2m1(packed, scales, magnitudes=None, block_size=16):
    """Dequantize packed E2M1 (FP4) values back to float32, row-wise.

    Each byte of ``packed`` holds two consecutive elements along the LAST
    dimension: the low nibble is the even element and the high nibble the odd
    one. This is the same convention the original script used to unpack the
    activations; it is assumed to match ``_quantize_to_e2m1``, so confirm this
    if that helper changes. Within a nibble, bit 3 is the sign and bits 0-2
    index ``magnitudes``. ``scales`` holds one scale per ``block_size``
    consecutive elements along the last dimension.

    Args:
        packed: (R, C // 2) int8 or uint8 tensor of packed nibbles.
        scales: (R, C // block_size) tensor of block scales (any dtype that
            ``.to(torch.float32)`` accepts, e.g. float8).
        magnitudes: 8-entry lookup table of E2M1 magnitudes. Defaults to
            ``_E2M1_MAGNITUDES`` from the kernel package.
        block_size: number of consecutive elements that share one scale.

    Returns:
        (R, C) float32 tensor of dequantized values.
    """
    if magnitudes is None:
        magnitudes = _E2M1_MAGNITUDES
    u8 = packed.view(torch.uint8)
    lo = (u8 & 0x0F).long()
    hi = ((u8 >> 4) & 0x0F).long()
    # Stack on a new trailing axis, then merge it into the packed axis, so that
    # lo/hi interleave along the last dimension: [lo0, hi0, lo1, hi1, ...].
    nib = torch.stack([lo, hi], dim=-1).reshape(*u8.shape[:-1], -1)
    signs = 1.0 - 2.0 * (nib >> 3).float()
    mags = magnitudes.to(device=packed.device, dtype=torch.float32)[nib & 0x07]
    scale = scales.to(torch.float32).repeat_interleave(block_size, dim=-1)
    return signs * mags * scale


def mean_row_cosine(a, b):
    """Return the per-row cosine similarity of ``a`` and ``b`` averaged over rows."""
    return torch.nn.functional.cosine_similarity(a.float(), b.float(), dim=-1).mean().item()


def main():
    """Compare the CUTLASS NVFP4 GEMM against a dequantized reference at M=1 and M=128."""
    if not torch.cuda.is_available():
        sys.exit("CUDA device required: the CUTLASS NVFP4 kernel only runs on GPU")

    torch.manual_seed(42)
    device = "cuda"
    M, N, K = 1, 6144, 7168
    linear = torch.nn.functional.linear

    # BF16 reference data and GEMM.
    x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device) * 2.0
    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device) * 0.5
    ref_out = linear(x_bf16, w_bf16.T)  # (M, N)
    print(f"BF16 ref: amax={ref_out.abs().max():.4e} mean={ref_out.mean():.4e}")

    # Quantize to NVFP4. _quantize_to_e2m1 packs pairs along the last dim, so
    # the weight is quantized as (N, K) to pack along K, then transposed to the
    # (K//2, N) / (K//16, N) layout passed to the kernel as the B operand.
    x_fp4, x_sf = _quantize_to_e2m1(x_bf16.float())  # (M, K//2), (M, K//16)
    w_fp4_t, w_sf_t = _quantize_to_e2m1(w_bf16.T.float())  # (N, K//2), (N, K//16)
    # NOTE(review): these are non-contiguous transposed views; confirm the
    # kernel honours strides (or expects exactly this K-major memory layout).
    w_fp4_final = w_fp4_t.T  # (K//2, N)
    w_sf_final = w_sf_t.T  # (K//16, N)
    print(f"x_fp4: {x_fp4.shape} x_sf: {x_sf.shape}")
    print(f"w_fp4: {w_fp4_final.shape} w_sf: {w_sf_final.shape}")

    # Reference GEMM on the dequantized operands. The weight is dequantized in
    # its contiguous (N, K//2) form, where packing runs along the last axis just
    # like the activations; the result is already in linear()'s (out, in) layout.
    x_recon = dequantize_e2m1(x_fp4, x_sf).to(torch.bfloat16)  # (M, K)
    w_recon = dequantize_e2m1(w_fp4_t, w_sf_t).to(torch.bfloat16)  # (N, K)
    quant_ref = linear(x_recon, w_recon)
    print(f"Quant ref: amax={quant_ref.abs().max():.4e} mean={quant_ref.mean():.4e}")
    # Sanity check on the reference itself: FP4 quantization error should still
    # leave this close to 1. A low value means the dequantization is wrong, not
    # the kernel.
    print(f"Quant ref vs BF16 ref: cosine={mean_row_cosine(quant_ref, ref_out):.6f}")

    # Run the CUTLASS GEMM at M=1.
    nvfp4_out = cutlass_nvfp4_blockscaled_gemm(
        x_fp4, x_sf,
        w_fp4_final, w_sf_final,
        M, N, K,
        alpha=1.0,
    )
    print(f"NVFP4 out: amax={nvfp4_out.abs().max():.4e} mean={nvfp4_out.mean():.4e}")
    cos = mean_row_cosine(nvfp4_out, quant_ref)
    mse = (nvfp4_out.float() - quant_ref.float()).pow(2).mean().item()
    print(f"cosine={cos:.6f} mse={mse:.4e}")

    # Same weights, M=128 activations.
    M2 = 128
    x2 = torch.randn(M2, K, dtype=torch.bfloat16, device=device) * 2.0
    x2_fp4, x2_sf = _quantize_to_e2m1(x2.float())
    x2_recon = dequantize_e2m1(x2_fp4, x2_sf).to(torch.bfloat16)
    qr2 = linear(x2_recon, w_recon)
    nv2 = cutlass_nvfp4_blockscaled_gemm(x2_fp4, x2_sf, w_fp4_final, w_sf_final, M2, N, K, alpha=1.0)
    cos2 = mean_row_cosine(nv2, qr2)
    print(f"M=128: cosine={cos2:.6f}")


if __name__ == "__main__":
    main()