"""
NVFP4-1.1 Phase 1: Verify the FP4 quantization math in a CuTeDSL kernel.

Uses a two-pass approach to avoid indexing Python lists with CuTeDSL loop
variables:

1. First pass: load the 16 values and compute their amax and the FP8 block scale.
2. Second pass: reload the 16 values and quantize them to FP4.

cute.arch.store has been confirmed to be available on B200.
"""
|
|
|
|
import torch
|
|
import cutlass
|
|
import cutlass.cute as cute
|
|
import cutlass.torch as cutlass_torch
|
|
import sys
|
|
import os
|
|
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
|
|
|
|
from dsv4.ops.quantize import quantize_activation_nvfp4, SF_VEC_SIZE
|
|
from dsv4.kernels.gemm.fp4_quant import (
|
|
fp8_e4m3_from_float32,
|
|
fp8_e4m3_to_float32,
|
|
quantize_e2m1_nibble,
|
|
)
|
|
|
|
|
|
@cute.kernel
def fp4_quant_test_kernel(
    input_bf16: cute.Tensor,  # (16,) BF16
    out_fp4: cute.Tensor,  # (8,) Int32 — packed FP4 bytes
    out_sf: cute.Tensor,  # (1,) Int32 — FP8 scale byte
    gs_scalar: cute.Tensor,  # (1,) Float32 — global scale
):
    """Quantize 16 BF16 values to NVFP4.

    Two-pass: (1) compute amax + FP8 block scale, (2) quantize + pack.

    Only thread 0 does any work. It is intended to be launched with
    Grid (1,1,1) and Block (32,1,1); the launch configuration itself is
    chosen by the host-side caller.

    Args:
        input_bf16: the 16 BF16 inputs. Each is divided by the global scale
            before quantization.
        out_fp4: receives 8 packed bytes, one per Int32 element. Element ``i``
            holds input ``2*i`` in its low nibble and input ``2*i+1`` in its
            high nibble.
        out_sf: element 0 receives the FP8 E4M3 block-scale bits.
        gs_scalar: element 0 holds the global scale.
    """
    tidx, _, _ = cute.arch.thread_idx()

    if tidx == cutlass.Int32(0):
        # Tensors are indexed by element. The typed iterators also advance
        # in elements, so adding byte offsets (i * 2 for BF16, i * 4 for
        # Int32) would address the wrong, out-of-bounds locations.
        gs = gs_scalar[0]

        # ── Pass 1: amax of |x / gs| over the 16-element block ──
        amax = cutlass.Float32(0.0)
        for i in cutlass.range(16, unroll=1):
            v = input_bf16[i].to(cutlass.Float32) / gs
            a = cute.arch.fmax(v, cutlass.Float32(0.0) - v)  # |v|
            amax = cute.arch.fmax(amax, a)

        # Block scale = amax / 6. An amax below 6 * 2^-9 (a scale below
        # 2^-9) is forced to a zero scale.
        bsf_f32 = amax / cutlass.Float32(6.0)
        if amax < cutlass.Float32(6.0 * (2.0 ** -9)):
            bsf_f32 = cutlass.Float32(0.0)

        # Cast to FP8 E4M3, then decode again so that pass 2 quantizes
        # against the scale value that is actually stored.
        sf_bits = fp8_e4m3_from_float32(bsf_f32)
        bs_dequant = fp8_e4m3_to_float32(sf_bits)

        # FP8 scale byte, stored in an Int32 element.
        out_sf[0] = sf_bits

        # ── Pass 2: quantize pairs and pack two nibbles per byte ──
        for i in cutlass.range(8, unroll=1):
            # Even element (2*i) -> low nibble.
            v0 = input_bf16[2 * i].to(cutlass.Float32) / gs
            nibble0 = quantize_e2m1_nibble(v0, bs_dequant)

            # Odd element (2*i + 1) -> high nibble.
            v1 = input_bf16[2 * i + cutlass.Int32(1)].to(cutlass.Float32) / gs
            nibble1 = quantize_e2m1_nibble(v1, bs_dequant)

            out_fp4[i] = (nibble1 << cutlass.Int32(4)) | nibble0
|
|
|
|
|
|
@cute.jit
def _launch_fp4_quant_test(
    input_bf16: cute.Tensor,
    out_fp4: cute.Tensor,
    out_sf: cute.Tensor,
    gs_scalar: cute.Tensor,
):
    """Host-side JIT entry point: launch fp4_quant_test_kernel on one 32-thread block."""
    fp4_quant_test_kernel(input_bf16, out_fp4, out_sf, gs_scalar).launch(
        grid=[1, 1, 1],
        block=[32, 1, 1],
    )


def run_test():
    """Run the FP4 quantization test on GPU.

    Quantizes 16 random BF16 values twice: once with the reference
    ``quantize_activation_nvfp4`` and once with ``fp4_quant_test_kernel``.
    It then compares the 8 packed FP4 bytes and the FP8 block-scale byte.

    Returns:
        True if both outputs match the reference. False otherwise, including
        when no CUDA device is available or when compilation or launch raises.
    """
    if not torch.cuda.is_available():
        print("\n❌ ERROR: a CUDA device is required")
        return False

    device = "cuda"
    N = 16  # the kernel is hard-wired to exactly one 16-element scale block
    # The reference must use the same block size, or the comparison is meaningless.
    if SF_VEC_SIZE != N:
        print(f"\n❌ ERROR: SF_VEC_SIZE={SF_VEC_SIZE}, but the kernel expects {N}")
        return False

    torch.manual_seed(42)
    x_bf16 = torch.randn(1, N, dtype=torch.bfloat16, device=device)

    amax_val = x_bf16.float().abs().max().item()
    # amax / (6.0 * 448.0); the 1e-8 floor keeps the scale nonzero.
    global_scale = max(amax_val / (6.0 * 448.0), 1e-8)

    ref_fp4, ref_sf = quantize_activation_nvfp4(x_bf16, global_scale)
    ref_fp4_bytes = ref_fp4.view(torch.uint8).reshape(-1).cpu()
    # Flatten like ref_fp4, so that [0] is a single byte whatever ref_sf's rank.
    ref_sf_bytes = ref_sf.view(torch.uint8).reshape(-1).cpu()

    print(f"Global scale: {global_scale:.8f}")
    print(f"Ref FP4: {ref_fp4_bytes}")
    print(f"Ref SF: {ref_sf_bytes}")

    out_fp4 = torch.zeros(8, dtype=torch.int32, device=device)
    out_sf = torch.zeros(1, dtype=torch.int32, device=device)
    gs_tensor = torch.tensor([global_scale], dtype=torch.float32, device=device)

    def to_cute(t):
        ct = cutlass_torch.from_dlpack(t)
        return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))

    x_flat = x_bf16.reshape(N).contiguous()
    kernel_args = (
        to_cute(x_flat),
        to_cute(out_fp4),
        to_cute(out_sf),
        to_cute(gs_tensor),
    )

    print("\nCompiling kernel...")
    try:
        # cute.compile takes the @cute.jit host function, which owns the
        # grid/block launch configuration; a bare @cute.kernel has none.
        compiled = cute.compile(_launch_fp4_quant_test, *kernel_args)
        print("Compiled. Running...")
        compiled(*kernel_args)
        torch.cuda.synchronize()
    except Exception as e:
        print(f"\n❌ ERROR: {e}")
        import traceback

        traceback.print_exc()
        return False

    # Compare the raw Int32 words. Casting to uint8 first would silently
    # drop any stray bits above the low byte.
    our_fp4 = out_fp4.cpu()
    our_sf = out_sf[0].item()

    print(f"Our FP4: {our_fp4}")
    print(f"Our SF: {our_sf}")

    fp4_match = torch.equal(our_fp4, ref_fp4_bytes[:8].to(torch.int32))
    sf_match = our_sf == ref_sf_bytes[0].item()

    if fp4_match and sf_match:
        print("\n✅ PASS!")
        return True
    print(f"\n❌ FAIL: FP4={fp4_match} SF={sf_match}")
    return False
|
|
|
|
|
|
if __name__ == "__main__":
    print("=" * 60)
    print("NVFP4-1.1 Phase 1: FP4 Quantization Math Test")
    print("=" * 60)
    success = run_test()
    # Use sys.exit: the bare exit() builtin is injected by the `site` module
    # and is missing under `python -S` or in embedded interpreters.
    sys.exit(0 if success else 1)
|