NVFP4-1.1: add CuTeDSL kernel test for FP4 quantization
This commit is contained in:
@@ -1,20 +1,21 @@
|
||||
"""
|
||||
NVFP4-1.1 Phase 1: Verify FP4 quantization math in CuTeDSL.
|
||||
NVFP4-1.1 Phase 1: Verify FP4 quantization math in CuTeDSL kernel.
|
||||
|
||||
Tests the fp4_quant.py functions on B200. Compares CuTeDSL kernel output
|
||||
with Python reference (quantize_activation_nvfp4).
|
||||
Tests that fp4_quant.py functions produce bit-exact matches with the
|
||||
Python reference (quantize_activation_nvfp4). Runs on B200 only.
|
||||
|
||||
The kernel takes 16 BF16 values + global_scale, quantizes to NVFP4,
|
||||
and writes FP4 packed bytes + FP8 scale byte to output tensors.
|
||||
Strategy: Launch a kernel that processes 16 BF16 values through the
|
||||
quantization pipeline and writes results to GMEM. Compare with Python.
|
||||
|
||||
Uses cute.arch.load for scalar GMEM reads (proven pattern from the codebase).
|
||||
For writes, uses the output tensor's iterator + offset pattern.
|
||||
Uses cute.arch.load for scalar GMEM reads (proven pattern).
|
||||
For GMEM writes, uses cute.arch.store on the output tensor's iterator + offset.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
from cutlass.cute.nvgpu import cpasync
|
||||
import sys
|
||||
import os
|
||||
|
||||
@@ -22,7 +23,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
|
||||
|
||||
from dsv4.ops.quantize import quantize_activation_nvfp4, SF_VEC_SIZE
|
||||
from dsv4.kernels.gemm.fp4_quant import (
|
||||
fp8_e4m3_from_float32_manual,
|
||||
fp8_e4m3_from_float32,
|
||||
fp8_e4m3_to_float32,
|
||||
half_step_to_e2m1_idx,
|
||||
quantize_e2m1_nibble,
|
||||
@@ -32,7 +33,8 @@ from dsv4.kernels.gemm.fp4_quant import (
|
||||
@cute.kernel
|
||||
def fp4_quant_test_kernel(
|
||||
input_bf16: cute.Tensor, # (16,) BF16 — 16 input values
|
||||
out_data: cute.Tensor, # (10,) Int32 — [0..7] = FP4 packed bytes, [8] = SF byte, [9] = debug
|
||||
out_fp4: cute.Tensor, # (8,) Int32 — packed FP4 bytes
|
||||
out_sf: cute.Tensor, # (1,) Int32 — FP8 scale byte
|
||||
gs_scalar: cute.Tensor, # (1,) Float32 — global scale
|
||||
):
|
||||
"""Quantize 16 BF16 values to NVFP4 using fp4_quant functions.
|
||||
@@ -55,43 +57,40 @@ def fp4_quant_test_kernel(
|
||||
)
|
||||
vals_f32[i] = bf16_val.to(cutlass.Float32) / gs
|
||||
|
||||
# ── Compute per-16-element amax ──
|
||||
# Compute per-16-element amax
|
||||
amax = cutlass.Float32(0.0)
|
||||
for i in cutlass.range(16, unroll=1):
|
||||
v = vals_f32[i]
|
||||
a = cute.math.fmax(v, cutlass.Float32(0.0) - v) # abs
|
||||
amax = cute.math.fmax(amax, a)
|
||||
|
||||
# ── Block scale = amax / 6 ──
|
||||
# Block scale = amax / 6
|
||||
bsf_f32 = amax / cutlass.Float32(6.0)
|
||||
# Underflow: if amax < 6 * 2^-9, force scale = 0
|
||||
underflow_threshold = cutlass.Float32(6.0 * (2.0 ** -9))
|
||||
if amax < underflow_threshold:
|
||||
bsf_f32 = cutlass.Float32(0.0)
|
||||
|
||||
# ── FP8 E4M3 cast ──
|
||||
sf_bits = fp8_e4m3_from_float32_manual(bsf_f32)
|
||||
# FP8 E4M3 cast
|
||||
sf_bits = fp8_e4m3_from_float32(bsf_f32)
|
||||
|
||||
# ── Dequantize FP8 scale (round-trip) ──
|
||||
# Dequantize FP8 scale (round-trip)
|
||||
bs_dequant = fp8_e4m3_to_float32(sf_bits)
|
||||
|
||||
# ── Quantize each value to E2M1 and pack ──
|
||||
# Quantize each value to E2M1 and pack
|
||||
for i in cutlass.range(8, unroll=1):
|
||||
nibble0 = quantize_e2m1_nibble(vals_f32[2 * i], bs_dequant)
|
||||
nibble1 = quantize_e2m1_nibble(vals_f32[2 * i + 1], bs_dequant)
|
||||
packed = (nibble1 << cutlass.Int32(4)) | nibble0
|
||||
# Write packed byte as Int32
|
||||
cute.arch.store(out_data.iterator + i * cutlass.Int32(4), packed, cutlass.Int32)
|
||||
cute.arch.store(out_fp4.iterator + i * cutlass.Int32(4), packed, cutlass.Int32)
|
||||
|
||||
# ── Write FP8 scale byte ──
|
||||
cute.arch.store(out_data.iterator + cutlass.Int32(8) * cutlass.Int32(4), sf_bits, cutlass.Int32)
|
||||
|
||||
# ── Debug: write bsf_f32 and bs_dequant as float ──
|
||||
# out_data[9] is unused — let's skip for simplicity
|
||||
# Write FP8 scale byte
|
||||
cute.arch.store(out_sf.iterator, sf_bits, cutlass.Int32)
|
||||
|
||||
|
||||
def run_test():
|
||||
"""Run the FP4 quantization test."""
|
||||
"""Run the FP4 quantization test on GPU."""
|
||||
device = "cuda"
|
||||
N = 16
|
||||
|
||||
@@ -99,7 +98,7 @@ def run_test():
|
||||
torch.manual_seed(42)
|
||||
x_bf16 = torch.randn(1, N, dtype=torch.bfloat16, device=device)
|
||||
|
||||
# Compute global scale (matching quantize_activation_nvfp4)
|
||||
# Compute global scale
|
||||
x_f32 = x_bf16.float()
|
||||
amax_val = x_f32.abs().max().item()
|
||||
global_scale = max(amax_val / (6.0 * 448.0), 1e-8)
|
||||
@@ -111,11 +110,12 @@ def run_test():
|
||||
|
||||
print(f"Input BF16 (first 8): {x_bf16[0, :8].cpu()}")
|
||||
print(f"Global scale: {global_scale:.8f}")
|
||||
print(f"Ref FP4 bytes: {ref_fp4_bytes}")
|
||||
print(f"Ref SF byte: {ref_sf_bytes}")
|
||||
print(f"Ref FP4: {ref_fp4_bytes}")
|
||||
print(f"Ref SF: {ref_sf_bytes}")
|
||||
|
||||
# Prepare output tensor
|
||||
out_data = torch.zeros(10, dtype=torch.int32, device=device)
|
||||
# Prepare output tensors
|
||||
out_fp4 = torch.zeros(8, dtype=torch.int32, device=device)
|
||||
out_sf = torch.zeros(1, dtype=torch.int32, device=device)
|
||||
gs_tensor = torch.tensor([global_scale], dtype=torch.float32, device=device)
|
||||
|
||||
# Convert to CuTe tensors
|
||||
@@ -125,7 +125,8 @@ def run_test():
|
||||
|
||||
x_flat = x_bf16.reshape(N).contiguous()
|
||||
input_c = to_cute(x_flat)
|
||||
out_c = to_cute(out_data)
|
||||
out_fp4_c = to_cute(out_fp4)
|
||||
out_sf_c = to_cute(out_sf)
|
||||
gs_c = to_cute(gs_tensor)
|
||||
|
||||
# Compile and run
|
||||
@@ -133,39 +134,44 @@ def run_test():
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
print("\nCompiling kernel (first run may take a minute)...")
|
||||
compiled = cute.compile(
|
||||
fp4_quant_test_kernel,
|
||||
input_c, out_c, gs_c,
|
||||
stream,
|
||||
)
|
||||
print("Compiled. Running...")
|
||||
compiled(input_c, out_c, gs_c, stream)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Extract results
|
||||
our_fp4 = out_data[:8].to(torch.uint8).cpu()
|
||||
our_sf = out_data[8].to(torch.uint8).cpu().item()
|
||||
|
||||
print(f"\nOur FP4 bytes: {our_fp4}")
|
||||
print(f"Our SF byte: {our_sf}")
|
||||
|
||||
# Compare
|
||||
fp4_match = torch.equal(our_fp4, ref_fp4_bytes[:8])
|
||||
sf_match = our_sf == ref_sf_bytes[0].item()
|
||||
|
||||
if fp4_match and sf_match:
|
||||
print("\n✅ PASS: FP4 quantization matches Python reference!")
|
||||
return True
|
||||
else:
|
||||
print(f"\n❌ FAIL: FP4 match={fp4_match}, SF match={sf_match}")
|
||||
if not fp4_match:
|
||||
for i in range(8):
|
||||
o = our_fp4[i].item()
|
||||
r = ref_fp4_bytes[i].item()
|
||||
if o != r:
|
||||
print(f" Byte {i}: ours=0x{o:02x}, ref=0x{r:02x}")
|
||||
if not sf_match:
|
||||
print(f" SF: ours=0x{our_sf:02x}, ref=0x{ref_sf_bytes[0].item():02x}")
|
||||
try:
|
||||
compiled = cute.compile(
|
||||
fp4_quant_test_kernel,
|
||||
input_c, out_fp4_c, out_sf_c, gs_c,
|
||||
stream,
|
||||
)
|
||||
print("Compiled. Running...")
|
||||
compiled(input_c, out_fp4_c, out_sf_c, gs_c, stream)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Extract results
|
||||
our_fp4 = out_fp4[:8].to(torch.uint8).cpu()
|
||||
our_sf = out_sf[0].to(torch.uint8).cpu().item()
|
||||
|
||||
print(f"\nOur FP4: {our_fp4}")
|
||||
print(f"Our SF: {our_sf}")
|
||||
|
||||
fp4_match = torch.equal(our_fp4, ref_fp4_bytes[:8])
|
||||
sf_match = our_sf == ref_sf_bytes[0].item()
|
||||
|
||||
if fp4_match and sf_match:
|
||||
print("\n✅ PASS: FP4 quantization matches Python reference!")
|
||||
return True
|
||||
else:
|
||||
print(f"\n❌ FAIL: FP4 match={fp4_match}, SF match={sf_match}")
|
||||
if not fp4_match:
|
||||
for i in range(8):
|
||||
o = our_fp4[i].item()
|
||||
r = ref_fp4_bytes[i].item()
|
||||
if o != r:
|
||||
print(f" Byte {i}: ours=0x{o:02x}, ref=0x{r:02x}")
|
||||
if not sf_match:
|
||||
print(f" SF: ours=0x{our_sf:02x}, ref=0x{ref_sf_bytes[0].item():02x}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"\n❌ ERROR: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user