Files
nvfp4-megamoe-kernel/tests/unit/test_nvfp4_1_1_quant.py

146 lines
4.7 KiB
Python

"""
NVFP4-1.1 Phase 1: Verify FP4 quantization math in CuTeDSL kernel.
Simplified: uses Float32 input to avoid BF16 scalar load issues.
Two-pass: (1) compute amax+scale, (2) quantize+pack.
"""
import torch
import cutlass
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
from dsv4.ops.quantize import quantize_activation_nvfp4, SF_VEC_SIZE
from dsv4.kernels.gemm.fp4_quant import (
fp8_e4m3_from_float32,
fp8_e4m3_to_float32,
quantize_e2m1_nibble,
)
@cute.kernel
def fp4_quant_test_kernel(
    input_f32: cute.Tensor,  # (16,) Float32 input values
    out_fp4: cute.Tensor,  # (8,) Int32, one packed byte (two E2M1 nibbles) per element
    out_sf: cute.Tensor,  # (1,) Int32, FP8 E4M3 block-scale bits
    gs_scalar: cute.Tensor,  # (1,) Float32 global scale; every input is divided by it
):
    """Quantize one 16-element block to NVFP4 using a single thread.

    Pass 1 computes amax over ``|x / gs|`` and derives the FP8 E4M3 block
    scale (written to ``out_sf``). Pass 2 quantizes each ``x / gs`` to an E2M1
    nibble against the dequantized block scale and packs consecutive pairs
    into ``out_fp4`` (even index -> low nibble, odd index -> high nibble).

    Note: CuTe pointer arithmetic is typed, so ``ptr + n`` advances by ``n``
    ELEMENTS of the pointee type, not by ``n`` bytes. All offsets below are
    therefore element indices.
    """
    tidx, _, _ = cute.arch.thread_idx()
    # Scalar reference implementation: only thread 0 does any work.
    if tidx == cutlass.Int32(0):
        gs = cute.arch.load(gs_scalar.iterator, cutlass.Float32)
        in_ptr = input_f32.iterator
        out_ptr = out_fp4.iterator

        # Pass 1: amax of |x / gs| over the 16-element block.
        amax = cutlass.Float32(0.0)
        for i in cutlass.range(16, unroll=1):
            v = cute.arch.load(in_ptr + i, cutlass.Float32) / gs
            a = cute.arch.fmax(v, cutlass.Float32(0.0) - v)  # |v|
            amax = cute.arch.fmax(amax, a)

        # Block scale: map amax onto the E2M1 maximum (6.0). 2**-9 is the
        # smallest positive E4M3 subnormal; smaller scales are flushed to 0.
        bsf_f32 = amax / cutlass.Float32(6.0)
        if amax < cutlass.Float32(6.0 * (2.0 ** -9)):
            bsf_f32 = cutlass.Float32(0.0)
        sf_bits = fp8_e4m3_from_float32(bsf_f32)
        # Quantize against the scale as it will be decoded (after FP8 rounding).
        bs_dequant = fp8_e4m3_to_float32(sf_bits)

        # Write SF
        cute.arch.store(out_sf.iterator, sf_bits)

        # Pass 2: quantize element pairs (2i, 2i+1) and pack them into one byte.
        for i in cutlass.range(8, unroll=1):
            v0 = cute.arch.load(in_ptr + 2 * i, cutlass.Float32) / gs
            nibble0 = quantize_e2m1_nibble(v0, bs_dequant)
            v1 = cute.arch.load(in_ptr + (2 * i + cutlass.Int32(1)), cutlass.Float32) / gs
            nibble1 = quantize_e2m1_nibble(v1, bs_dequant)
            packed = (nibble1 << cutlass.Int32(4)) | nibble0
            cute.arch.store(out_ptr + i, packed)
@cute.jit
def _launch_fp4_quant_test(
    input_f32: cute.Tensor,
    out_fp4: cute.Tensor,
    out_sf: cute.Tensor,
    gs_scalar: cute.Tensor,
):
    """Host-side JIT entry point that launches the test kernel on one thread.

    ``cute.compile`` expects a ``@cute.jit`` function; a ``@cute.kernel`` must
    be launched from inside one via ``.launch(grid=..., block=...)``.
    """
    fp4_quant_test_kernel(input_f32, out_fp4, out_sf, gs_scalar).launch(
        grid=[1, 1, 1],
        block=[1, 1, 1],
    )


def run_test():
    """Quantize one random 16-element row on the GPU and compare with the reference.

    Returns:
        True if both the packed FP4 bytes and the block-scale byte match
        ``quantize_activation_nvfp4``. False on a mismatch or on any error;
        compile/launch errors are reported with a traceback rather than raised.
    """
    device = "cuda"
    N = 16
    torch.manual_seed(42)
    x_bf16 = torch.randn(1, N, dtype=torch.bfloat16, device=device)
    x_f32 = x_bf16.float()
    amax_val = x_f32.abs().max().item()
    # Map amax onto 6 (E2M1 max) * 448 (E4M3 max); the floor avoids division by 0.
    global_scale = max(amax_val / (6.0 * 448.0), 1e-8)
    # Pre-divide by global_scale so the kernel can run with gs = 1.0.
    x_norm = (x_f32 / global_scale).reshape(N).contiguous()

    # Python reference: takes the raw BF16 input plus global_scale
    # (presumably normalizes internally -- confirm against dsv4.ops.quantize).
    ref_fp4, ref_sf = quantize_activation_nvfp4(x_bf16, global_scale)
    ref_fp4_bytes = ref_fp4.view(torch.uint8).reshape(-1).cpu()
    # Flatten so ref_sf_bytes[0] is a scalar even if the SF tensor is multi-dim.
    ref_sf_bytes = ref_sf.view(torch.uint8).reshape(-1).cpu()
    print(f"Global scale: {global_scale:.8f}")
    print(f"Ref FP4: {ref_fp4_bytes}")
    print(f"Ref SF: {ref_sf_bytes}")

    # Output tensors: one packed byte per Int32 slot, one SF byte.
    n_packed = N // 2
    out_fp4 = torch.zeros(n_packed, dtype=torch.int32, device=device)
    out_sf = torch.zeros(1, dtype=torch.int32, device=device)
    gs_tensor = torch.tensor([1.0], dtype=torch.float32, device=device)  # already normalized

    def to_cute(t):
        ct = cutlass_torch.from_dlpack(t)
        return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))

    input_c = to_cute(x_norm)
    out_fp4_c = to_cute(out_fp4)
    out_sf_c = to_cute(out_sf)
    gs_c = to_cute(gs_tensor)

    print("\nCompiling kernel...")
    try:
        compiled = cute.compile(
            _launch_fp4_quant_test,
            input_c, out_fp4_c, out_sf_c, gs_c,
        )
        print("Compiled. Running...")
        compiled(input_c, out_fp4_c, out_sf_c, gs_c)
        torch.cuda.synchronize()

        our_fp4 = out_fp4.cpu()
        our_sf = out_sf[0].item()
        print(f"Our FP4: {our_fp4}")
        print(f"Our SF: {our_sf}")
        # Compare at int32 width so stray high bits in a packed slot are not
        # masked by a narrowing cast to uint8.
        fp4_match = torch.equal(our_fp4, ref_fp4_bytes[:n_packed].to(torch.int32))
        sf_match = our_sf == ref_sf_bytes[0].item()
        if fp4_match and sf_match:
            print("\n✅ PASS!")
            return True
        print(f"\n❌ FAIL: FP4={fp4_match} SF={sf_match}")
        return False
    except Exception as e:
        # Deliberate top-level boundary for a standalone script: report and fail.
        print(f"\n❌ ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False
if __name__ == "__main__":
print("=" * 60)
print("NVFP4-1.1 Phase 1: FP4 Quantization Math Test")
print("Threshold rounding — Float32 input — no BF16 scalar loads")
print("=" * 60)
success = run_test()
exit(0 if success else 1)