NVFP4-1.1: fix test kernel - use cute.copy instead of cute.arch.store

This commit is contained in:
2026-05-28 03:42:24 +00:00
parent 3a78bdf570
commit a41de129cb

View File

@@ -4,11 +4,8 @@ NVFP4-1.1 Phase 1: Verify FP4 quantization math in CuTeDSL kernel.
Tests that fp4_quant.py functions produce bit-exact matches with the
Python reference (quantize_activation_nvfp4). Runs on B200 only.
Strategy: Launch a kernel that processes 16 BF16 values through the
quantization pipeline and writes results to GMEM. Compare with Python.
Uses cute.arch.load for scalar GMEM reads (proven pattern).
For GMEM writes, uses cute.copy with a simple CopyUniversalOp atom.
Uses cute.arch.load for scalar GMEM reads.
Uses cute.copy with CopyUniversalOp for GMEM writes (no cute.arch.store).
"""
import torch
@@ -32,61 +29,74 @@ from dsv4.kernels.gemm.fp4_quant import (
@cute.kernel
def fp4_quant_test_kernel(
input_bf16: cute.Tensor, # (16,) BF16 — 16 input values
input_bf16: cute.Tensor, # (16,) BF16
out_fp4: cute.Tensor, # (8,) Int32 — packed FP4 bytes
out_sf: cute.Tensor, # (1,) Int32 — FP8 scale byte
gs_scalar: cute.Tensor, # (1,) Float32 — global scale
):
"""Quantize 16 BF16 values to NVFP4 using fp4_quant functions.
"""Quantize 16 BF16 values to NVFP4.
Single-thread kernel (only thread 0 does work).
Grid: (1, 1, 1), Block: (32, 1, 1)
Thread 0 does all work. Results written via cute.copy.
"""
tidx, _, _ = cute.arch.thread_idx()
# Create a copy atom for Int32 GMEM writes (1 element per copy)
copy_atom = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(), cutlass.Int32, num_bits_per_copy=32,
)
if tidx == cutlass.Int32(0):
# Load global scale
gs = cute.arch.load(gs_scalar.iterator, cutlass.Float32)
# Load 16 BF16 values, convert to FP32, normalize by global_scale
# Load 16 BF16 values, convert to FP32, normalize
vals_f32 = [cutlass.Float32(0.0)] * 16
for i in cutlass.range(16, unroll=1):
bf16_val = cute.arch.load(
input_bf16.iterator + i * cutlass.Int32(2), # BF16 = 2 bytes
input_bf16.iterator + i * cutlass.Int32(2),
cutlass.BFloat16,
)
vals_f32[i] = bf16_val.to(cutlass.Float32) / gs
# Compute per-16-element amax
# Per-16-element amax
amax = cutlass.Float32(0.0)
for i in cutlass.range(16, unroll=1):
v = vals_f32[i]
a = cute.math.fmax(v, cutlass.Float32(0.0) - v) # abs
a = cute.math.fmax(v, cutlass.Float32(0.0) - v)
amax = cute.math.fmax(amax, a)
# Block scale = amax / 6
# Block scale
bsf_f32 = amax / cutlass.Float32(6.0)
# Underflow: if amax < 6 * 2^-9, force scale = 0
underflow_threshold = cutlass.Float32(6.0 * (2.0 ** -9))
if amax < underflow_threshold:
bsf_f32 = cutlass.Float32(0.0)
# FP8 E4M3 cast
# FP8 E4M3 cast + dequant
sf_bits = fp8_e4m3_from_float32(bsf_f32)
# Dequantize FP8 scale (round-trip)
bs_dequant = fp8_e4m3_to_float32(sf_bits)
# Quantize each value to E2M1 and pack
# E2M1 quantize + pack
for i in cutlass.range(8, unroll=1):
nibble0 = quantize_e2m1_nibble(vals_f32[2 * i], bs_dequant)
nibble1 = quantize_e2m1_nibble(vals_f32[2 * i + 1], bs_dequant)
packed = (nibble1 << cutlass.Int32(4)) | nibble0
# Write packed byte as Int32
cute.arch.store(out_fp4.iterator + i * cutlass.Int32(4), packed, cutlass.Int32)
# Write to GMEM via cute.copy (1 Int32 element)
rmem = cute.make_rmem_tensor((1,), cutlass.Int32)
rmem[cutlass.Int32(0)] = packed
gmem = cute.make_tensor(
out_fp4.iterator + i * cutlass.Int32(4),
cute.make_layout((1,)),
)
cute.copy(copy_atom, rmem, gmem)
# Write FP8 scale byte
cute.arch.store(out_sf.iterator, sf_bits, cutlass.Int32)
# Write FP8 scale
rmem_sf = cute.make_rmem_tensor((1,), cutlass.Int32)
rmem_sf[cutlass.Int32(0)] = sf_bits
gmem_sf = cute.make_tensor(
out_sf.iterator,
cute.make_layout((1,)),
)
cute.copy(copy_atom, rmem_sf, gmem_sf)
def run_test():
@@ -94,31 +104,25 @@ def run_test():
device = "cuda"
N = 16
# Generate test input
torch.manual_seed(42)
x_bf16 = torch.randn(1, N, dtype=torch.bfloat16, device=device)
# Compute global scale
x_f32 = x_bf16.float()
amax_val = x_f32.abs().max().item()
global_scale = max(amax_val / (6.0 * 448.0), 1e-8)
# Python reference
ref_fp4, ref_sf = quantize_activation_nvfp4(x_bf16, global_scale)
ref_fp4_bytes = ref_fp4.view(torch.uint8).reshape(-1).cpu()
ref_sf_bytes = ref_sf.view(torch.uint8).cpu()
print(f"Input BF16 (first 8): {x_bf16[0, :8].cpu()}")
print(f"Global scale: {global_scale:.8f}")
print(f"Ref FP4: {ref_fp4_bytes}")
print(f"Ref SF: {ref_sf_bytes}")
# Prepare output tensors
out_fp4 = torch.zeros(8, dtype=torch.int32, device=device)
out_sf = torch.zeros(1, dtype=torch.int32, device=device)
gs_tensor = torch.tensor([global_scale], dtype=torch.float32, device=device)
# Convert to CuTe tensors
def to_cute(t):
ct = cutlass_torch.from_dlpack(t)
return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
@@ -129,44 +133,30 @@ def run_test():
out_sf_c = to_cute(out_sf)
gs_c = to_cute(gs_tensor)
# Compile and run
import cuda.bindings.driver as cuda
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
print("\nCompiling kernel (first run may take a minute)...")
print("\nCompiling kernel...")
try:
compiled = cute.compile(
fp4_quant_test_kernel,
input_c, out_fp4_c, out_sf_c, gs_c,
stream,
)
print("Compiled. Running...")
compiled(input_c, out_fp4_c, out_sf_c, gs_c, stream)
compiled(input_c, out_fp4_c, out_sf_c, gs_c)
torch.cuda.synchronize()
# Extract results
our_fp4 = out_fp4[:8].to(torch.uint8).cpu()
our_sf = out_sf[0].to(torch.uint8).cpu().item()
print(f"\nOur FP4: {our_fp4}")
print(f"Our FP4: {our_fp4}")
print(f"Our SF: {our_sf}")
fp4_match = torch.equal(our_fp4, ref_fp4_bytes[:8])
sf_match = our_sf == ref_sf_bytes[0].item()
if fp4_match and sf_match:
print("\n✅ PASS: FP4 quantization matches Python reference!")
print("\n✅ PASS!")
return True
else:
print(f"\n❌ FAIL: FP4 match={fp4_match}, SF match={sf_match}")
if not fp4_match:
for i in range(8):
o = our_fp4[i].item()
r = ref_fp4_bytes[i].item()
if o != r:
print(f" Byte {i}: ours=0x{o:02x}, ref=0x{r:02x}")
if not sf_match:
print(f" SF: ours=0x{our_sf:02x}, ref=0x{ref_sf_bytes[0].item():02x}")
print(f"\n❌ FAIL: FP4={fp4_match} SF={sf_match}")
return False
except Exception as e:
print(f"\n❌ ERROR: {e}")
@@ -178,7 +168,6 @@ def run_test():
if __name__ == "__main__":
print("=" * 60)
print("NVFP4-1.1 Phase 1: FP4 Quantization Math Test")
print("Verifies fp4_quant.py functions match Python reference")
print("=" * 60)
success = run_test()
exit(0 if success else 1)