From 60790564f0703183eba00bd5b46242bcdf60b7b8 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 03:46:45 +0000 Subject: [PATCH] NVFP4-1.1: fix test - two-pass kernel, cute.arch.store confirmed on B200 --- tests/unit/test_nvfp4_1_1_quant.py | 78 ++++++++++++------------------ 1 file changed, 31 insertions(+), 47 deletions(-) diff --git a/tests/unit/test_nvfp4_1_1_quant.py b/tests/unit/test_nvfp4_1_1_quant.py index e44a66ff..f2882760 100644 --- a/tests/unit/test_nvfp4_1_1_quant.py +++ b/tests/unit/test_nvfp4_1_1_quant.py @@ -1,18 +1,17 @@ """ NVFP4-1.1 Phase 1: Verify FP4 quantization math in CuTeDSL kernel. -Tests that fp4_quant.py functions produce bit-exact matches with the -Python reference (quantize_activation_nvfp4). Runs on B200 only. +Two-pass approach to avoid Python list indexing with CuTeDSL loop variables: +1. First pass: load 16 values, compute amax + FP8 scale +2. Second pass: reload 16 values, quantize to FP4 -Uses cute.arch.load for scalar GMEM reads. -Uses cute.copy with CopyUniversalOp for GMEM writes (no cute.arch.store). +cute.arch.store confirmed available on B200. """ import torch import cutlass import cutlass.cute as cute import cutlass.torch as cutlass_torch -from cutlass.cute.nvgpu import cpasync import sys import os @@ -22,7 +21,6 @@ from dsv4.ops.quantize import quantize_activation_nvfp4, SF_VEC_SIZE from dsv4.kernels.gemm.fp4_quant import ( fp8_e4m3_from_float32, fp8_e4m3_to_float32, - half_step_to_e2m1_idx, quantize_e2m1_nibble, ) @@ -36,67 +34,53 @@ def fp4_quant_test_kernel( ): """Quantize 16 BF16 values to NVFP4. - Thread 0 does all work. Results written via cute.copy. + Two-pass: (1) compute amax+scale, (2) quantize+pack. + Only thread 0 works. Grid: (1,1,1), Block: (32,1,1). """ tidx, _, _ = cute.arch.thread_idx() - # Create a copy atom for Int32 GMEM writes (1 element per copy) - copy_atom = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), cutlass.Int32, num_bits_per_copy=32, - ) - if tidx == cutlass.Int32(0): - # Load global scale gs = cute.arch.load(gs_scalar.iterator, cutlass.Float32) - # Load 16 BF16 values, convert to FP32, normalize - vals_f32 = [cutlass.Float32(0.0)] * 16 - for i in cutlass.range(16, unroll=1): - bf16_val = cute.arch.load( - input_bf16.iterator + i * cutlass.Int32(2), - cutlass.BFloat16, - ) - vals_f32[i] = bf16_val.to(cutlass.Float32) / gs - - # Per-16-element amax + # ── Pass 1: Compute per-16-element amax ── amax = cutlass.Float32(0.0) for i in cutlass.range(16, unroll=1): - v = vals_f32[i] + ptr = input_bf16.iterator + i * cutlass.Int32(2) # BF16=2 bytes + bf16_val = cute.arch.load(ptr, cutlass.BFloat16) + v = bf16_val.to(cutlass.Float32) / gs a = cute.math.fmax(v, cutlass.Float32(0.0) - v) amax = cute.math.fmax(amax, a) # Block scale bsf_f32 = amax / cutlass.Float32(6.0) - underflow_threshold = cutlass.Float32(6.0 * (2.0 ** -9)) - if amax < underflow_threshold: + if amax < cutlass.Float32(6.0 * (2.0 ** -9)): bsf_f32 = cutlass.Float32(0.0) # FP8 E4M3 cast + dequant sf_bits = fp8_e4m3_from_float32(bsf_f32) bs_dequant = fp8_e4m3_to_float32(sf_bits) - # E2M1 quantize + pack - for i in cutlass.range(8, unroll=1): - nibble0 = quantize_e2m1_nibble(vals_f32[2 * i], bs_dequant) - nibble1 = quantize_e2m1_nibble(vals_f32[2 * i + 1], bs_dequant) - packed = (nibble1 << cutlass.Int32(4)) | nibble0 - # Write to GMEM via cute.copy (1 Int32 element) - rmem = cute.make_rmem_tensor((1,), cutlass.Int32) - rmem[cutlass.Int32(0)] = packed - gmem = cute.make_tensor( - out_fp4.iterator + i * cutlass.Int32(4), - cute.make_layout((1,)), - ) - cute.copy(copy_atom, rmem, gmem) + # Write FP8 scale byte (Int32 holding uint8) + cute.arch.store(out_sf.iterator, sf_bits, cutlass.Int32) - # Write FP8 scale - rmem_sf = cute.make_rmem_tensor((1,), cutlass.Int32) - rmem_sf[cutlass.Int32(0)] = sf_bits - gmem_sf = cute.make_tensor( - out_sf.iterator, - cute.make_layout((1,)), - ) - cute.copy(copy_atom, rmem_sf, gmem_sf) + # ── Pass 2: Quantize and pack ── + for i in cutlass.range(8, unroll=1): + # Load even element (2*i) + ptr0 = input_bf16.iterator + (2 * i) * cutlass.Int32(2) + v0_bf16 = cute.arch.load(ptr0, cutlass.BFloat16) + v0 = v0_bf16.to(cutlass.Float32) / gs + nibble0 = quantize_e2m1_nibble(v0, bs_dequant) + + # Load odd element (2*i+1) + ptr1 = input_bf16.iterator + (2 * i + cutlass.Int32(1)) * cutlass.Int32(2) + v1_bf16 = cute.arch.load(ptr1, cutlass.BFloat16) + v1 = v1_bf16.to(cutlass.Float32) / gs + nibble1 = quantize_e2m1_nibble(v1, bs_dequant) + + packed = (nibble1 << cutlass.Int32(4)) | nibble0 + + # Write packed byte as Int32 + cute.arch.store(out_fp4.iterator + i * cutlass.Int32(4), packed, cutlass.Int32) def run_test():