NVFP4-1.1: fix test - two-pass kernel, cute.arch.store confirmed on B200
This commit is contained in:
@@ -1,18 +1,17 @@
|
||||
"""
|
||||
NVFP4-1.1 Phase 1: Verify FP4 quantization math in CuTeDSL kernel.
|
||||
|
||||
Tests that fp4_quant.py functions produce bit-exact matches with the
|
||||
Python reference (quantize_activation_nvfp4). Runs on B200 only.
|
||||
Two-pass approach to avoid Python list indexing with CuTeDSL loop variables:
|
||||
1. First pass: load 16 values, compute amax + FP8 scale
|
||||
2. Second pass: reload 16 values, quantize to FP4
|
||||
|
||||
Uses cute.arch.load for scalar GMEM reads.
|
||||
Uses cute.copy with CopyUniversalOp for GMEM writes (no cute.arch.store).
|
||||
cute.arch.store confirmed available on B200.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
from cutlass.cute.nvgpu import cpasync
|
||||
import sys
|
||||
import os
|
||||
|
||||
@@ -22,7 +21,6 @@ from dsv4.ops.quantize import quantize_activation_nvfp4, SF_VEC_SIZE
|
||||
from dsv4.kernels.gemm.fp4_quant import (
|
||||
fp8_e4m3_from_float32,
|
||||
fp8_e4m3_to_float32,
|
||||
half_step_to_e2m1_idx,
|
||||
quantize_e2m1_nibble,
|
||||
)
|
||||
|
||||
@@ -36,67 +34,53 @@ def fp4_quant_test_kernel(
|
||||
):
|
||||
"""Quantize 16 BF16 values to NVFP4.
|
||||
|
||||
Thread 0 does all work. Results written via cute.copy.
|
||||
Two-pass: (1) compute amax+scale, (2) quantize+pack.
|
||||
Only thread 0 works. Grid: (1,1,1), Block: (32,1,1).
|
||||
"""
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
|
||||
# Create a copy atom for Int32 GMEM writes (1 element per copy)
|
||||
copy_atom = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(), cutlass.Int32, num_bits_per_copy=32,
|
||||
)
|
||||
|
||||
if tidx == cutlass.Int32(0):
|
||||
# Load global scale
|
||||
gs = cute.arch.load(gs_scalar.iterator, cutlass.Float32)
|
||||
|
||||
# Load 16 BF16 values, convert to FP32, normalize
|
||||
vals_f32 = [cutlass.Float32(0.0)] * 16
|
||||
for i in cutlass.range(16, unroll=1):
|
||||
bf16_val = cute.arch.load(
|
||||
input_bf16.iterator + i * cutlass.Int32(2),
|
||||
cutlass.BFloat16,
|
||||
)
|
||||
vals_f32[i] = bf16_val.to(cutlass.Float32) / gs
|
||||
|
||||
# Per-16-element amax
|
||||
# ── Pass 1: Compute per-16-element amax ──
|
||||
amax = cutlass.Float32(0.0)
|
||||
for i in cutlass.range(16, unroll=1):
|
||||
v = vals_f32[i]
|
||||
ptr = input_bf16.iterator + i * cutlass.Int32(2) # BF16=2 bytes
|
||||
bf16_val = cute.arch.load(ptr, cutlass.BFloat16)
|
||||
v = bf16_val.to(cutlass.Float32) / gs
|
||||
a = cute.math.fmax(v, cutlass.Float32(0.0) - v)
|
||||
amax = cute.math.fmax(amax, a)
|
||||
|
||||
# Block scale
|
||||
bsf_f32 = amax / cutlass.Float32(6.0)
|
||||
underflow_threshold = cutlass.Float32(6.0 * (2.0 ** -9))
|
||||
if amax < underflow_threshold:
|
||||
if amax < cutlass.Float32(6.0 * (2.0 ** -9)):
|
||||
bsf_f32 = cutlass.Float32(0.0)
|
||||
|
||||
# FP8 E4M3 cast + dequant
|
||||
sf_bits = fp8_e4m3_from_float32(bsf_f32)
|
||||
bs_dequant = fp8_e4m3_to_float32(sf_bits)
|
||||
|
||||
# E2M1 quantize + pack
|
||||
for i in cutlass.range(8, unroll=1):
|
||||
nibble0 = quantize_e2m1_nibble(vals_f32[2 * i], bs_dequant)
|
||||
nibble1 = quantize_e2m1_nibble(vals_f32[2 * i + 1], bs_dequant)
|
||||
packed = (nibble1 << cutlass.Int32(4)) | nibble0
|
||||
# Write to GMEM via cute.copy (1 Int32 element)
|
||||
rmem = cute.make_rmem_tensor((1,), cutlass.Int32)
|
||||
rmem[cutlass.Int32(0)] = packed
|
||||
gmem = cute.make_tensor(
|
||||
out_fp4.iterator + i * cutlass.Int32(4),
|
||||
cute.make_layout((1,)),
|
||||
)
|
||||
cute.copy(copy_atom, rmem, gmem)
|
||||
# Write FP8 scale byte (Int32 holding uint8)
|
||||
cute.arch.store(out_sf.iterator, sf_bits, cutlass.Int32)
|
||||
|
||||
# Write FP8 scale
|
||||
rmem_sf = cute.make_rmem_tensor((1,), cutlass.Int32)
|
||||
rmem_sf[cutlass.Int32(0)] = sf_bits
|
||||
gmem_sf = cute.make_tensor(
|
||||
out_sf.iterator,
|
||||
cute.make_layout((1,)),
|
||||
)
|
||||
cute.copy(copy_atom, rmem_sf, gmem_sf)
|
||||
# ── Pass 2: Quantize and pack ──
|
||||
for i in cutlass.range(8, unroll=1):
|
||||
# Load even element (2*i)
|
||||
ptr0 = input_bf16.iterator + (2 * i) * cutlass.Int32(2)
|
||||
v0_bf16 = cute.arch.load(ptr0, cutlass.BFloat16)
|
||||
v0 = v0_bf16.to(cutlass.Float32) / gs
|
||||
nibble0 = quantize_e2m1_nibble(v0, bs_dequant)
|
||||
|
||||
# Load odd element (2*i+1)
|
||||
ptr1 = input_bf16.iterator + (2 * i + cutlass.Int32(1)) * cutlass.Int32(2)
|
||||
v1_bf16 = cute.arch.load(ptr1, cutlass.BFloat16)
|
||||
v1 = v1_bf16.to(cutlass.Float32) / gs
|
||||
nibble1 = quantize_e2m1_nibble(v1, bs_dequant)
|
||||
|
||||
packed = (nibble1 << cutlass.Int32(4)) | nibble0
|
||||
|
||||
# Write packed byte as Int32
|
||||
cute.arch.store(out_fp4.iterator + i * cutlass.Int32(4), packed, cutlass.Int32)
|
||||
|
||||
|
||||
def run_test():
|
||||
|
||||
Reference in New Issue
Block a user