NVFP4-1.1: test kernel uses Float32 input (avoids BF16 scalar load issue)

This commit is contained in:
2026-05-28 04:32:08 +00:00
parent d2aa93aad7
commit 1828a71cde

View File

@@ -1,11 +1,8 @@
"""
NVFP4-1.1 Phase 1: Verify FP4 quantization math in CuTeDSL kernel.
Two-pass approach:
1. First pass: load 16 values, compute amax + FP8 scale
2. Second pass: reload 16 values, quantize to FP4
Uses threshold rounding (no float-to-int conversion).
Simplified: uses Float32 input to avoid BF16 scalar load issues.
Two-pass: (1) compute amax+scale, (2) quantize+pack.
"""
import torch
@@ -27,22 +24,21 @@ from dsv4.kernels.gemm.fp4_quant import (
@cute.kernel
def fp4_quant_test_kernel(
input_bf16: cute.Tensor, # (16,) BF16
out_fp4: cute.Tensor, # (8,) Int32 — packed FP4 bytes
out_sf: cute.Tensor, # (1,) Int32 — FP8 scale byte
gs_scalar: cute.Tensor, # (1,) Float32 — global scale
input_f32: cute.Tensor, # (16,) Float32
out_fp4: cute.Tensor, # (8,) Int32
out_sf: cute.Tensor, # (1,) Int32
gs_scalar: cute.Tensor, # (1,) Float32
):
tidx, _, _ = cute.arch.thread_idx()
if tidx == cutlass.Int32(0):
gs = cute.arch.load(gs_scalar.iterator, cutlass.Float32)
# Pass 1: amax
# Pass 1: amax (load Float32, divide by global_scale)
amax = cutlass.Float32(0.0)
for i in cutlass.range(16, unroll=1):
ptr = input_bf16.iterator + i * cutlass.Int32(2)
bf16_val = cute.arch.load(ptr, cutlass.BFloat16)
v = bf16_val.to(cutlass.Float32) / gs
ptr = input_f32.iterator + i * cutlass.Int32(4) # Float32 = 4 bytes
v = cute.arch.load(ptr, cutlass.Float32) / gs
a = cute.arch.fmax(v, cutlass.Float32(0.0) - v)
amax = cute.arch.fmax(amax, a)
@@ -59,14 +55,12 @@ def fp4_quant_test_kernel(
# Pass 2: quantize and pack
for i in cutlass.range(8, unroll=1):
ptr0 = input_bf16.iterator + (2 * i) * cutlass.Int32(2)
v0_bf16 = cute.arch.load(ptr0, cutlass.BFloat16)
v0 = v0_bf16.to(cutlass.Float32) / gs
ptr0 = input_f32.iterator + (2 * i) * cutlass.Int32(4)
v0 = cute.arch.load(ptr0, cutlass.Float32) / gs
nibble0 = quantize_e2m1_nibble(v0, bs_dequant)
ptr1 = input_bf16.iterator + (2 * i + cutlass.Int32(1)) * cutlass.Int32(2)
v1_bf16 = cute.arch.load(ptr1, cutlass.BFloat16)
v1 = v1_bf16.to(cutlass.Float32) / gs
ptr1 = input_f32.iterator + (2 * i + cutlass.Int32(1)) * cutlass.Int32(4)
v1 = cute.arch.load(ptr1, cutlass.Float32) / gs
nibble1 = quantize_e2m1_nibble(v1, bs_dequant)
packed = (nibble1 << cutlass.Int32(4)) | nibble0
@@ -84,6 +78,10 @@ def run_test():
amax_val = x_f32.abs().max().item()
global_scale = max(amax_val / (6.0 * 448.0), 1e-8)
# Pre-divide by global_scale (kernel expects normalized input)
x_norm = (x_f32 / global_scale).reshape(N).contiguous()
# Python reference (on the same normalized values)
ref_fp4, ref_sf = quantize_activation_nvfp4(x_bf16, global_scale)
ref_fp4_bytes = ref_fp4.view(torch.uint8).reshape(-1).cpu()
ref_sf_bytes = ref_sf.view(torch.uint8).cpu()
@@ -92,16 +90,16 @@ def run_test():
print(f"Ref FP4: {ref_fp4_bytes}")
print(f"Ref SF: {ref_sf_bytes}")
# Output tensors
out_fp4 = torch.zeros(8, dtype=torch.int32, device=device)
out_sf = torch.zeros(1, dtype=torch.int32, device=device)
gs_tensor = torch.tensor([global_scale], dtype=torch.float32, device=device)
gs_tensor = torch.tensor([1.0], dtype=torch.float32, device=device) # already normalized
def to_cute(t):
ct = cutlass_torch.from_dlpack(t)
return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
x_flat = x_bf16.reshape(N).contiguous()
input_c = to_cute(x_flat)
input_c = to_cute(x_norm)
out_fp4_c = to_cute(out_fp4)
out_sf_c = to_cute(out_sf)
gs_c = to_cute(gs_tensor)
@@ -141,7 +139,7 @@ def run_test():
if __name__ == "__main__":
print("=" * 60)
print("NVFP4-1.1 Phase 1: FP4 Quantization Math Test")
print("Threshold rounding — no float-to-int conversion")
print("Threshold rounding — Float32 input — no BF16 scalar loads")
print("=" * 60)
success = run_test()
exit(0 if success else 1)