NVFP4-1.1: fix test - two-pass kernel, cute.arch.store confirmed on B200

This commit is contained in:
2026-05-28 03:46:45 +00:00
parent ca9f920414
commit 60790564f0

View File

@@ -1,18 +1,17 @@
"""
NVFP4-1.1 Phase 1: Verify FP4 quantization math in CuTeDSL kernel.
Tests that fp4_quant.py functions produce bit-exact matches with the
Python reference (quantize_activation_nvfp4). Runs on B200 only.
Two-pass approach to avoid Python list indexing with CuTeDSL loop variables:
1. First pass: load 16 values, compute amax + FP8 scale
2. Second pass: reload 16 values, quantize to FP4
Uses cute.arch.load for scalar GMEM reads.
Uses cute.copy with CopyUniversalOp for GMEM writes (no cute.arch.store).
cute.arch.store confirmed available on B200.
"""
import torch
import cutlass
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
from cutlass.cute.nvgpu import cpasync
import sys
import os
@@ -22,7 +21,6 @@ from dsv4.ops.quantize import quantize_activation_nvfp4, SF_VEC_SIZE
from dsv4.kernels.gemm.fp4_quant import (
fp8_e4m3_from_float32,
fp8_e4m3_to_float32,
half_step_to_e2m1_idx,
quantize_e2m1_nibble,
)
@@ -36,67 +34,53 @@ def fp4_quant_test_kernel(
):
"""Quantize 16 BF16 values to NVFP4.
Thread 0 does all work. Results written via cute.copy.
Two-pass: (1) compute amax+scale, (2) quantize+pack.
Only thread 0 works. Grid: (1,1,1), Block: (32,1,1).
"""
tidx, _, _ = cute.arch.thread_idx()
# Create a copy atom for Int32 GMEM writes (1 element per copy)
copy_atom = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(), cutlass.Int32, num_bits_per_copy=32,
)
if tidx == cutlass.Int32(0):
# Load global scale
gs = cute.arch.load(gs_scalar.iterator, cutlass.Float32)
# Load 16 BF16 values, convert to FP32, normalize
vals_f32 = [cutlass.Float32(0.0)] * 16
for i in cutlass.range(16, unroll=1):
bf16_val = cute.arch.load(
input_bf16.iterator + i * cutlass.Int32(2),
cutlass.BFloat16,
)
vals_f32[i] = bf16_val.to(cutlass.Float32) / gs
# Per-16-element amax
# ── Pass 1: Compute per-16-element amax ──
amax = cutlass.Float32(0.0)
for i in cutlass.range(16, unroll=1):
v = vals_f32[i]
ptr = input_bf16.iterator + i * cutlass.Int32(2) # BF16=2 bytes
bf16_val = cute.arch.load(ptr, cutlass.BFloat16)
v = bf16_val.to(cutlass.Float32) / gs
a = cute.math.fmax(v, cutlass.Float32(0.0) - v)
amax = cute.math.fmax(amax, a)
# Block scale
bsf_f32 = amax / cutlass.Float32(6.0)
underflow_threshold = cutlass.Float32(6.0 * (2.0 ** -9))
if amax < underflow_threshold:
if amax < cutlass.Float32(6.0 * (2.0 ** -9)):
bsf_f32 = cutlass.Float32(0.0)
# FP8 E4M3 cast + dequant
sf_bits = fp8_e4m3_from_float32(bsf_f32)
bs_dequant = fp8_e4m3_to_float32(sf_bits)
# E2M1 quantize + pack
for i in cutlass.range(8, unroll=1):
nibble0 = quantize_e2m1_nibble(vals_f32[2 * i], bs_dequant)
nibble1 = quantize_e2m1_nibble(vals_f32[2 * i + 1], bs_dequant)
packed = (nibble1 << cutlass.Int32(4)) | nibble0
# Write to GMEM via cute.copy (1 Int32 element)
rmem = cute.make_rmem_tensor((1,), cutlass.Int32)
rmem[cutlass.Int32(0)] = packed
gmem = cute.make_tensor(
out_fp4.iterator + i * cutlass.Int32(4),
cute.make_layout((1,)),
)
cute.copy(copy_atom, rmem, gmem)
# Write FP8 scale byte (Int32 holding uint8)
cute.arch.store(out_sf.iterator, sf_bits, cutlass.Int32)
# Write FP8 scale
rmem_sf = cute.make_rmem_tensor((1,), cutlass.Int32)
rmem_sf[cutlass.Int32(0)] = sf_bits
gmem_sf = cute.make_tensor(
out_sf.iterator,
cute.make_layout((1,)),
)
cute.copy(copy_atom, rmem_sf, gmem_sf)
# ── Pass 2: Quantize and pack ──
for i in cutlass.range(8, unroll=1):
# Load even element (2*i)
ptr0 = input_bf16.iterator + (2 * i) * cutlass.Int32(2)
v0_bf16 = cute.arch.load(ptr0, cutlass.BFloat16)
v0 = v0_bf16.to(cutlass.Float32) / gs
nibble0 = quantize_e2m1_nibble(v0, bs_dequant)
# Load odd element (2*i+1)
ptr1 = input_bf16.iterator + (2 * i + cutlass.Int32(1)) * cutlass.Int32(2)
v1_bf16 = cute.arch.load(ptr1, cutlass.BFloat16)
v1 = v1_bf16.to(cutlass.Float32) / gs
nibble1 = quantize_e2m1_nibble(v1, bs_dequant)
packed = (nibble1 << cutlass.Int32(4)) | nibble0
# Write packed byte as Int32
cute.arch.store(out_fp4.iterator + i * cutlass.Int32(4), packed, cutlass.Int32)
def run_test():