NVFP4-1.1: test kernel uses Float32 input (avoids BF16 scalar load issue)
This commit is contained in:
@@ -1,11 +1,8 @@
|
||||
"""
|
||||
NVFP4-1.1 Phase 1: Verify FP4 quantization math in CuTeDSL kernel.
|
||||
|
||||
Two-pass approach:
|
||||
1. First pass: load 16 values, compute amax + FP8 scale
|
||||
2. Second pass: reload 16 values, quantize to FP4
|
||||
|
||||
Uses threshold rounding (no float-to-int conversion).
|
||||
Simplified: uses Float32 input to avoid BF16 scalar load issues.
|
||||
Two-pass: (1) compute amax+scale, (2) quantize+pack.
|
||||
"""
|
||||
|
||||
import torch
|
||||
@@ -27,22 +24,21 @@ from dsv4.kernels.gemm.fp4_quant import (
|
||||
|
||||
@cute.kernel
|
||||
def fp4_quant_test_kernel(
|
||||
input_bf16: cute.Tensor, # (16,) BF16
|
||||
out_fp4: cute.Tensor, # (8,) Int32 — packed FP4 bytes
|
||||
out_sf: cute.Tensor, # (1,) Int32 — FP8 scale byte
|
||||
gs_scalar: cute.Tensor, # (1,) Float32 — global scale
|
||||
input_f32: cute.Tensor, # (16,) Float32
|
||||
out_fp4: cute.Tensor, # (8,) Int32
|
||||
out_sf: cute.Tensor, # (1,) Int32
|
||||
gs_scalar: cute.Tensor, # (1,) Float32
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
|
||||
if tidx == cutlass.Int32(0):
|
||||
gs = cute.arch.load(gs_scalar.iterator, cutlass.Float32)
|
||||
|
||||
# Pass 1: amax
|
||||
# Pass 1: amax (load Float32, divide by global_scale)
|
||||
amax = cutlass.Float32(0.0)
|
||||
for i in cutlass.range(16, unroll=1):
|
||||
ptr = input_bf16.iterator + i * cutlass.Int32(2)
|
||||
bf16_val = cute.arch.load(ptr, cutlass.BFloat16)
|
||||
v = bf16_val.to(cutlass.Float32) / gs
|
||||
ptr = input_f32.iterator + i * cutlass.Int32(4) # Float32 = 4 bytes
|
||||
v = cute.arch.load(ptr, cutlass.Float32) / gs
|
||||
a = cute.arch.fmax(v, cutlass.Float32(0.0) - v)
|
||||
amax = cute.arch.fmax(amax, a)
|
||||
|
||||
@@ -59,14 +55,12 @@ def fp4_quant_test_kernel(
|
||||
|
||||
# Pass 2: quantize and pack
|
||||
for i in cutlass.range(8, unroll=1):
|
||||
ptr0 = input_bf16.iterator + (2 * i) * cutlass.Int32(2)
|
||||
v0_bf16 = cute.arch.load(ptr0, cutlass.BFloat16)
|
||||
v0 = v0_bf16.to(cutlass.Float32) / gs
|
||||
ptr0 = input_f32.iterator + (2 * i) * cutlass.Int32(4)
|
||||
v0 = cute.arch.load(ptr0, cutlass.Float32) / gs
|
||||
nibble0 = quantize_e2m1_nibble(v0, bs_dequant)
|
||||
|
||||
ptr1 = input_bf16.iterator + (2 * i + cutlass.Int32(1)) * cutlass.Int32(2)
|
||||
v1_bf16 = cute.arch.load(ptr1, cutlass.BFloat16)
|
||||
v1 = v1_bf16.to(cutlass.Float32) / gs
|
||||
ptr1 = input_f32.iterator + (2 * i + cutlass.Int32(1)) * cutlass.Int32(4)
|
||||
v1 = cute.arch.load(ptr1, cutlass.Float32) / gs
|
||||
nibble1 = quantize_e2m1_nibble(v1, bs_dequant)
|
||||
|
||||
packed = (nibble1 << cutlass.Int32(4)) | nibble0
|
||||
@@ -84,6 +78,10 @@ def run_test():
|
||||
amax_val = x_f32.abs().max().item()
|
||||
global_scale = max(amax_val / (6.0 * 448.0), 1e-8)
|
||||
|
||||
# Pre-divide by global_scale (kernel expects normalized input)
|
||||
x_norm = (x_f32 / global_scale).reshape(N).contiguous()
|
||||
|
||||
# Python reference (called with the original BF16 input and global_scale, not x_norm)
|
||||
ref_fp4, ref_sf = quantize_activation_nvfp4(x_bf16, global_scale)
|
||||
ref_fp4_bytes = ref_fp4.view(torch.uint8).reshape(-1).cpu()
|
||||
ref_sf_bytes = ref_sf.view(torch.uint8).cpu()
|
||||
@@ -92,16 +90,16 @@ def run_test():
|
||||
print(f"Ref FP4: {ref_fp4_bytes}")
|
||||
print(f"Ref SF: {ref_sf_bytes}")
|
||||
|
||||
# Output tensors
|
||||
out_fp4 = torch.zeros(8, dtype=torch.int32, device=device)
|
||||
out_sf = torch.zeros(1, dtype=torch.int32, device=device)
|
||||
gs_tensor = torch.tensor([global_scale], dtype=torch.float32, device=device)
|
||||
gs_tensor = torch.tensor([1.0], dtype=torch.float32, device=device) # already normalized
|
||||
|
||||
def to_cute(t):
|
||||
ct = cutlass_torch.from_dlpack(t)
|
||||
return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
|
||||
|
||||
x_flat = x_bf16.reshape(N).contiguous()
|
||||
input_c = to_cute(x_flat)
|
||||
input_c = to_cute(x_norm)
|
||||
out_fp4_c = to_cute(out_fp4)
|
||||
out_sf_c = to_cute(out_sf)
|
||||
gs_c = to_cute(gs_tensor)
|
||||
@@ -141,7 +139,7 @@ def run_test():
|
||||
if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("NVFP4-1.1 Phase 1: FP4 Quantization Math Test")
|
||||
print("Threshold rounding — no float-to-int conversion")
|
||||
print("Threshold rounding — Float32 input — no BF16 scalar loads")
|
||||
print("=" * 60)
|
||||
success = run_test()
|
||||
exit(0 if success else 1)
|
||||
|
||||
Reference in New Issue
Block a user