NVFP4-1.1: update test kernel with threshold rounding API

This commit is contained in:
2026-05-28 04:27:29 +00:00
parent dabcc415a8
commit accc66741d

View File

@@ -1,11 +1,11 @@
"""
NVFP4-1.1 Phase 1: Verify FP4 quantization math in CuTeDSL kernel.
Two-pass approach to avoid Python list indexing with CuTeDSL loop variables:
Two-pass approach:
1. First pass: load 16 values, compute amax + FP8 scale
2. Second pass: reload 16 values, quantize to FP4
cute.arch.store confirmed available on B200.
Uses threshold rounding (no float-to-int conversion).
"""
import torch
@@ -32,20 +32,15 @@ def fp4_quant_test_kernel(
out_sf: cute.Tensor, # (1,) Int32 — FP8 scale byte
gs_scalar: cute.Tensor, # (1,) Float32 — global scale
):
"""Quantize 16 BF16 values to NVFP4.
Two-pass: (1) compute amax+scale, (2) quantize+pack.
Only thread 0 works. Grid: (1,1,1), Block: (32,1,1).
"""
tidx, _, _ = cute.arch.thread_idx()
if tidx == cutlass.Int32(0):
gs = cute.arch.load(gs_scalar.iterator, cutlass.Float32)
# ── Pass 1: Compute per-16-element amax ──
# Pass 1: amax
amax = cutlass.Float32(0.0)
for i in cutlass.range(16, unroll=1):
ptr = input_bf16.iterator + i * cutlass.Int32(2) # BF16=2 bytes
ptr = input_bf16.iterator + i * cutlass.Int32(2)
bf16_val = cute.arch.load(ptr, cutlass.BFloat16)
v = bf16_val.to(cutlass.Float32) / gs
a = cute.arch.fmax(v, cutlass.Float32(0.0) - v)
@@ -56,35 +51,29 @@ def fp4_quant_test_kernel(
if amax < cutlass.Float32(6.0 * (2.0 ** -9)):
bsf_f32 = cutlass.Float32(0.0)
# FP8 E4M3 cast + dequant
sf_bits = fp8_e4m3_from_float32(bsf_f32)
bs_dequant = fp8_e4m3_to_float32(sf_bits)
# Write FP8 scale byte (Int32 holding uint8)
# Write SF
cute.arch.store(out_sf.iterator, sf_bits)
# ── Pass 2: Quantize and pack ──
# Pass 2: quantize and pack
for i in cutlass.range(8, unroll=1):
# Load even element (2*i)
ptr0 = input_bf16.iterator + (2 * i) * cutlass.Int32(2)
v0_bf16 = cute.arch.load(ptr0, cutlass.BFloat16)
v0 = v0_bf16.to(cutlass.Float32) / gs
nibble0 = quantize_e2m1_nibble(v0, bs_dequant)
# Load odd element (2*i+1)
ptr1 = input_bf16.iterator + (2 * i + cutlass.Int32(1)) * cutlass.Int32(2)
v1_bf16 = cute.arch.load(ptr1, cutlass.BFloat16)
v1 = v1_bf16.to(cutlass.Float32) / gs
nibble1 = quantize_e2m1_nibble(v1, bs_dequant)
packed = (nibble1 << cutlass.Int32(4)) | nibble0
# Write packed byte as Int32
cute.arch.store(out_fp4.iterator + i * cutlass.Int32(4), packed)
def run_test():
"""Run the FP4 quantization test on GPU."""
device = "cuda"
N = 16
@@ -152,6 +141,7 @@ def run_test():
if __name__ == "__main__":
print("=" * 60)
print("NVFP4-1.1 Phase 1: FP4 Quantization Math Test")
print("Threshold rounding — no float-to-int conversion")
print("=" * 60)
success = run_test()
exit(0 if success else 1)