"""Standalone test matching real MoE dimensions: M=1, N=6144, K=7168.

The random test with M=128 showed cosine 1.0, but real inference with M=1
shows cosine ≈ 0. This test uses deterministic data at M=1 to reproduce.
"""

import sys

import torch


def dequantize_nvfp4(packed, scales, magnitudes, group_size=16):
    """Expand packed 4-bit sign/magnitude codes into a dense float32 tensor.

    Each byte of ``packed`` holds two codes along the LAST axis: the low
    nibble is element ``2*i`` and the high nibble is element ``2*i + 1``.
    Bit 3 of a code is the sign and bits 0-2 index into ``magnitudes``.

    Args:
        packed: ``(..., L // 2)`` tensor with one-byte elements (int8/uint8).
        scales: ``(..., L // group_size)`` per-group scale factors; converted
            to float32 before use.
        magnitudes: 1-D lookup table indexed by the low three bits of a code.
        group_size: Number of consecutive elements sharing one scale.

    Returns:
        ``(..., L)`` float32 tensor ``sign * magnitude * scale``.
    """
    codes = packed.view(torch.uint8)
    low = (codes & 0x0F).long()
    high = ((codes >> 4) & 0x0F).long()
    # Stack on a new trailing axis and merge it into the packed axis so the
    # two nibbles of a byte end up adjacent along that axis.
    nibbles = torch.stack([low, high], dim=-1).flatten(-2)
    signs = (nibbles >> 3).float() * -2 + 1
    mags = magnitudes.to(nibbles.device)[nibbles & 0x07]
    scale = scales.to(torch.float32).repeat_interleave(group_size, dim=-1)
    return signs * mags * scale


def _mean_cosine(a, b):
    """Return the row-wise cosine similarity of ``a`` and ``b``, averaged."""
    return torch.nn.functional.cosine_similarity(
        a.float(), b.float(), dim=-1
    ).mean().item()


def main():
    """Compare the CUTLASS NVFP4 GEMM against a dequantized reference.

    Runs once with M=1 and once with M=128 against the same quantized
    weights and prints the cosine similarity (and MSE for M=1) between the
    kernel output and a BF16 matmul of the dequantized operands.
    """
    # The project package is resolved relative to ./src, so the path must be
    # extended before the imports below.
    sys.path.insert(0, 'src')
    from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import (
        cutlass_nvfp4_blockscaled_gemm,
    )
    from nvfp4_megamoe_kernel.nvfp4_mega_moe import (
        _E2M1_MAGNITUDES,
        _quantize_to_e2m1,
    )

    torch.manual_seed(42)
    device = "cuda"

    M, N, K = 1, 6144, 7168

    # Create BF16 reference data
    x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device) * 2.0
    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device) * 0.5

    # Reference BF16 GEMM
    ref_out = torch.nn.functional.linear(x_bf16, w_bf16.T)  # (M, N)
    print(f"BF16 ref: amax={ref_out.abs().max():.4e} mean={ref_out.mean():.4e}")

    # Quantize to NVFP4
    x_fp4, x_sf = _quantize_to_e2m1(x_bf16.float())  # (M, K//2), (M, K//16)

    # The quantizer packs along the last axis, so quantize the (N, K) view to
    # pack along K, then transpose to the (K//2, N) layout passed to CUTLASS.
    w_fp4_t, w_sf_t = _quantize_to_e2m1(w_bf16.T.float())  # (N, K//2), (N, K//16)
    # NOTE(review): these are transposed views, not contiguous copies —
    # confirm the kernel wrapper accepts strided inputs.
    w_fp4_final = w_fp4_t.T  # (K//2, N) = (3584, 6144)
    w_sf_final = w_sf_t.T    # (K//16, N) = (448, 6144)

    print(f"x_fp4: {x_fp4.shape} x_sf: {x_sf.shape}")
    print(f"w_fp4: {w_fp4_final.shape} w_sf: {w_sf_final.shape}")

    # Dequantize and compute reference from quantized values
    x_recon = dequantize_nvfp4(x_fp4, x_sf, _E2M1_MAGNITUDES).to(torch.bfloat16)
    # Unpack the weights in their (N, K//2) layout, where the nibble pairs lie
    # on the last axis. Interleaving the nibbles of the (K//2, N) view along
    # its last axis would scatter K-neighbours across N and corrupt the
    # reference.
    w_recon_t = dequantize_nvfp4(w_fp4_t, w_sf_t, _E2M1_MAGNITUDES).to(
        torch.bfloat16
    )  # (N, K), the weight layout F.linear expects

    quant_ref = torch.nn.functional.linear(x_recon, w_recon_t)
    print(f"Quant ref: amax={quant_ref.abs().max():.4e} mean={quant_ref.mean():.4e}")

    # Run CUTLASS GEMM
    nvfp4_out = cutlass_nvfp4_blockscaled_gemm(
        x_fp4, x_sf,
        w_fp4_final, w_sf_final,
        M, N, K,
        alpha=1.0,
    )
    print(f"NVFP4 out: amax={nvfp4_out.abs().max():.4e} mean={nvfp4_out.mean():.4e}")

    # Cosine similarity
    cos = _mean_cosine(nvfp4_out, quant_ref)
    mse = (nvfp4_out.float() - quant_ref.float()).pow(2).mean().item()
    print(f"cosine={cos:.6f} mse={mse:.4e}")

    # Also test with M=128
    M2 = 128
    x2 = torch.randn(M2, K, dtype=torch.bfloat16, device=device) * 2.0
    x2_fp4, x2_sf = _quantize_to_e2m1(x2.float())
    x2_recon = dequantize_nvfp4(x2_fp4, x2_sf, _E2M1_MAGNITUDES).to(torch.bfloat16)
    qr2 = torch.nn.functional.linear(x2_recon, w_recon_t)
    nv2 = cutlass_nvfp4_blockscaled_gemm(
        x2_fp4, x2_sf, w_fp4_final, w_sf_final, M2, N, K, alpha=1.0
    )
    cos2 = _mean_cosine(nv2, qr2)
    print(f"M=128: cosine={cos2:.6f}")


if __name__ == "__main__":
    main()