NVFP4-1.1: add CuTeDSL kernel test for FP4 quantization

This commit is contained in:
2026-05-28 03:40:54 +00:00
parent 80b6b79f9e
commit 3a78bdf570

View File

@@ -1,20 +1,21 @@
"""
NVFP4-1.1 Phase 1: Verify FP4 quantization math in CuTeDSL.
NVFP4-1.1 Phase 1: Verify FP4 quantization math in CuTeDSL kernel.
Tests the fp4_quant.py functions on B200. Compares CuTeDSL kernel output
with Python reference (quantize_activation_nvfp4).
Tests that fp4_quant.py functions produce bit-exact matches with the
Python reference (quantize_activation_nvfp4). Runs on B200 only.
The kernel takes 16 BF16 values + global_scale, quantizes to NVFP4,
and writes FP4 packed bytes + FP8 scale byte to output tensors.
Strategy: Launch a kernel that processes 16 BF16 values through the
quantization pipeline and writes results to GMEM. Compare with Python.
Uses cute.arch.load for scalar GMEM reads (proven pattern from the codebase).
For writes, uses the output tensor's iterator + offset pattern.
Uses cute.arch.load for scalar GMEM reads (proven pattern).
For GMEM writes, uses cute.copy with a simple CopyUniversalOp atom.
"""
import torch
import cutlass
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
from cutlass.cute.nvgpu import cpasync
import sys
import os
@@ -22,7 +23,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
from dsv4.ops.quantize import quantize_activation_nvfp4, SF_VEC_SIZE
from dsv4.kernels.gemm.fp4_quant import (
fp8_e4m3_from_float32_manual,
fp8_e4m3_from_float32,
fp8_e4m3_to_float32,
half_step_to_e2m1_idx,
quantize_e2m1_nibble,
@@ -32,7 +33,8 @@ from dsv4.kernels.gemm.fp4_quant import (
@cute.kernel
def fp4_quant_test_kernel(
input_bf16: cute.Tensor, # (16,) BF16 — 16 input values
out_data: cute.Tensor, # (10,) Int32 — [0..7] = FP4 packed bytes, [8] = SF byte, [9] = debug
out_fp4: cute.Tensor, # (8,) Int32 — packed FP4 bytes
out_sf: cute.Tensor, # (1,) Int32 — FP8 scale byte
gs_scalar: cute.Tensor, # (1,) Float32 — global scale
):
"""Quantize 16 BF16 values to NVFP4 using fp4_quant functions.
@@ -55,43 +57,40 @@ def fp4_quant_test_kernel(
)
vals_f32[i] = bf16_val.to(cutlass.Float32) / gs
# ── Compute per-16-element amax ──
# Compute per-16-element amax
amax = cutlass.Float32(0.0)
for i in cutlass.range(16, unroll=1):
v = vals_f32[i]
a = cute.math.fmax(v, cutlass.Float32(0.0) - v) # abs
amax = cute.math.fmax(amax, a)
# ── Block scale = amax / 6 ──
# Block scale = amax / 6
bsf_f32 = amax / cutlass.Float32(6.0)
# Underflow: if amax < 6 * 2^-9, force scale = 0
underflow_threshold = cutlass.Float32(6.0 * (2.0 ** -9))
if amax < underflow_threshold:
bsf_f32 = cutlass.Float32(0.0)
# ── FP8 E4M3 cast ──
sf_bits = fp8_e4m3_from_float32_manual(bsf_f32)
# FP8 E4M3 cast
sf_bits = fp8_e4m3_from_float32(bsf_f32)
# ── Dequantize FP8 scale (round-trip) ──
# Dequantize FP8 scale (round-trip)
bs_dequant = fp8_e4m3_to_float32(sf_bits)
# ── Quantize each value to E2M1 and pack ──
# Quantize each value to E2M1 and pack
for i in cutlass.range(8, unroll=1):
nibble0 = quantize_e2m1_nibble(vals_f32[2 * i], bs_dequant)
nibble1 = quantize_e2m1_nibble(vals_f32[2 * i + 1], bs_dequant)
packed = (nibble1 << cutlass.Int32(4)) | nibble0
# Write packed byte as Int32
cute.arch.store(out_data.iterator + i * cutlass.Int32(4), packed, cutlass.Int32)
cute.arch.store(out_fp4.iterator + i * cutlass.Int32(4), packed, cutlass.Int32)
# ── Write FP8 scale byte ──
cute.arch.store(out_data.iterator + cutlass.Int32(8) * cutlass.Int32(4), sf_bits, cutlass.Int32)
# ── Debug: write bsf_f32 and bs_dequant as float ──
# out_data[9] is unused — let's skip for simplicity
# Write FP8 scale byte
cute.arch.store(out_sf.iterator, sf_bits, cutlass.Int32)
def run_test():
"""Run the FP4 quantization test."""
"""Run the FP4 quantization test on GPU."""
device = "cuda"
N = 16
@@ -99,7 +98,7 @@ def run_test():
torch.manual_seed(42)
x_bf16 = torch.randn(1, N, dtype=torch.bfloat16, device=device)
# Compute global scale (matching quantize_activation_nvfp4)
# Compute global scale
x_f32 = x_bf16.float()
amax_val = x_f32.abs().max().item()
global_scale = max(amax_val / (6.0 * 448.0), 1e-8)
@@ -111,11 +110,12 @@ def run_test():
print(f"Input BF16 (first 8): {x_bf16[0, :8].cpu()}")
print(f"Global scale: {global_scale:.8f}")
print(f"Ref FP4 bytes: {ref_fp4_bytes}")
print(f"Ref SF byte: {ref_sf_bytes}")
print(f"Ref FP4: {ref_fp4_bytes}")
print(f"Ref SF: {ref_sf_bytes}")
# Prepare output tensor
out_data = torch.zeros(10, dtype=torch.int32, device=device)
# Prepare output tensors
out_fp4 = torch.zeros(8, dtype=torch.int32, device=device)
out_sf = torch.zeros(1, dtype=torch.int32, device=device)
gs_tensor = torch.tensor([global_scale], dtype=torch.float32, device=device)
# Convert to CuTe tensors
@@ -125,7 +125,8 @@ def run_test():
x_flat = x_bf16.reshape(N).contiguous()
input_c = to_cute(x_flat)
out_c = to_cute(out_data)
out_fp4_c = to_cute(out_fp4)
out_sf_c = to_cute(out_sf)
gs_c = to_cute(gs_tensor)
# Compile and run
@@ -133,39 +134,44 @@ def run_test():
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
print("\nCompiling kernel (first run may take a minute)...")
compiled = cute.compile(
fp4_quant_test_kernel,
input_c, out_c, gs_c,
stream,
)
print("Compiled. Running...")
compiled(input_c, out_c, gs_c, stream)
torch.cuda.synchronize()
# Extract results
our_fp4 = out_data[:8].to(torch.uint8).cpu()
our_sf = out_data[8].to(torch.uint8).cpu().item()
print(f"\nOur FP4 bytes: {our_fp4}")
print(f"Our SF byte: {our_sf}")
# Compare
fp4_match = torch.equal(our_fp4, ref_fp4_bytes[:8])
sf_match = our_sf == ref_sf_bytes[0].item()
if fp4_match and sf_match:
print("\n✅ PASS: FP4 quantization matches Python reference!")
return True
else:
print(f"\n❌ FAIL: FP4 match={fp4_match}, SF match={sf_match}")
if not fp4_match:
for i in range(8):
o = our_fp4[i].item()
r = ref_fp4_bytes[i].item()
if o != r:
print(f" Byte {i}: ours=0x{o:02x}, ref=0x{r:02x}")
if not sf_match:
print(f" SF: ours=0x{our_sf:02x}, ref=0x{ref_sf_bytes[0].item():02x}")
try:
compiled = cute.compile(
fp4_quant_test_kernel,
input_c, out_fp4_c, out_sf_c, gs_c,
stream,
)
print("Compiled. Running...")
compiled(input_c, out_fp4_c, out_sf_c, gs_c, stream)
torch.cuda.synchronize()
# Extract results
our_fp4 = out_fp4[:8].to(torch.uint8).cpu()
our_sf = out_sf[0].to(torch.uint8).cpu().item()
print(f"\nOur FP4: {our_fp4}")
print(f"Our SF: {our_sf}")
fp4_match = torch.equal(our_fp4, ref_fp4_bytes[:8])
sf_match = our_sf == ref_sf_bytes[0].item()
if fp4_match and sf_match:
print("\n✅ PASS: FP4 quantization matches Python reference!")
return True
else:
print(f"\n❌ FAIL: FP4 match={fp4_match}, SF match={sf_match}")
if not fp4_match:
for i in range(8):
o = our_fp4[i].item()
r = ref_fp4_bytes[i].item()
if o != r:
print(f" Byte {i}: ours=0x{o:02x}, ref=0x{r:02x}")
if not sf_match:
print(f" SF: ours=0x{our_sf:02x}, ref=0x{ref_sf_bytes[0].item():02x}")
return False
except Exception as e:
print(f"\n❌ ERROR: {e}")
import traceback
traceback.print_exc()
return False