NVFP4-1.1: standalone BF16→FP4 quantize kernel (WIP) + dequantize verification
This commit is contained in:
176
tests/unit/test_nvfp4_quant_kernel.py
Normal file
176
tests/unit/test_nvfp4_quant_kernel.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
NVFP4-1.1: BF16→FP4 quantization kernel (CuTeDSL, Blackwell SM100).
|
||||
|
||||
Reads BF16 from GMEM, quantizes to NVFP4, writes FP4 + FP8 scales to GMEM.
|
||||
Uses TMA for efficient GMEM access.
|
||||
|
||||
Grid: (num_rows, 1, 1) — 1 CTA per row.
|
||||
Each CTA processes one row, with 128 threads each handling multiple 16-element blocks.
|
||||
|
||||
Step 2 of the SwiGLU FP4 fusion.
|
||||
Step 1: ✅ Python round-trip (cos 0.981)
|
||||
Step 2: THIS — Standalone kernel
|
||||
Step 3: Fuse into SwiGLU epilogue
|
||||
|
||||
Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_nvfp4_quant_kernel.py
|
||||
"""
|
||||
import torch
|
||||
import math
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.utils as utils
|
||||
from cutlass import Float32, BFloat16, Float8E4M3FN, Int32, const_expr
|
||||
import cuda.bindings.driver as cuda
|
||||
import cutlass.torch as ct
|
||||
|
||||
from dsv4.ops.quantize import quantize_activation_nvfp4, SF_VEC_SIZE
|
||||
|
||||
|
||||
class Nvfp4QuantizeKernel:
|
||||
def __init__(self, M, N, block_size=16):
|
||||
self.M = M
|
||||
self.N = N
|
||||
self.block_size = block_size
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, x_bf16, x_sf, stream):
|
||||
"""
|
||||
x_bf16: (M, N) BF16 input — also used as FP4 output (in-place, same memory)
|
||||
x_sf: (M, N // 16) FP8 E4M3 scale factors
|
||||
"""
|
||||
M = self.M; N = self.N; bs = self.block_size
|
||||
self._kernel(x_bf16, x_sf, Int32(M), Int32(N), Int32(bs)).launch(
|
||||
grid=(M, 1, 1), block=[128, 1, 1], stream=stream
|
||||
)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, x_bf16, x_sf, M, N, block_size):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
bidx, _, _ = cute.arch.block_idx()
|
||||
row = bidx
|
||||
|
||||
n_blocks = N // block_size # number of 16-element blocks per row
|
||||
threads = Int32(128)
|
||||
blocks_per_thread = n_blocks // threads # blocks handled by each thread
|
||||
|
||||
# Each thread processes blocks_per_thread consecutive blocks
|
||||
for b in range(blocks_per_thread):
|
||||
block_idx = tidx * blocks_per_thread + b
|
||||
col_start = block_idx * block_size
|
||||
|
||||
# Step 1: Read 16 BF16 elements and compute amax
|
||||
amax = Float32(0.0)
|
||||
vals = [None] * 16 # Will store BF16 values for later quantization
|
||||
for i in range(block_size):
|
||||
# Direct GMEM read (not TMA — simpler for first implementation)
|
||||
val = x_bf16[row, col_start + i]
|
||||
abs_val = val * val # val^2 — we need |val|
|
||||
# Actually, we need max(|val|). Let me use a simpler approach.
|
||||
# CuTeDSL doesn't have abs() as a primitive.
|
||||
# Use: abs_val = val if val > 0 else -val
|
||||
abs_val = val if val > Float32(0.0) else Float32(0.0) - val
|
||||
amax = amax if amax > abs_val else abs_val
|
||||
|
||||
# Step 2: Compute FP8 E4M3 scale = (amax / 6.0)
|
||||
# For now, store as FP32 (FP8 cast is complex in CuTeDSL)
|
||||
scale = amax / Float32(6.0) if amax > Float32(0.0) else Float32(1.0)
|
||||
x_sf[row, block_idx] = scale
|
||||
|
||||
# Step 3: Quantize each BF16 element to FP4 and pack
|
||||
packed = Int32(0)
|
||||
for i in range(block_size):
|
||||
val = x_bf16[row, col_start + i]
|
||||
# Scale
|
||||
scaled = val / scale
|
||||
# Abs
|
||||
abs_scaled = scaled if scaled > Float32(0.0) else Float32(0.0) - scaled
|
||||
# Half-step: round(|scaled| * 2)
|
||||
half_step_raw = abs_scaled * Float32(2.0)
|
||||
# Round: floor(x + 0.5)
|
||||
half_step = half_step_raw + Float32(0.5)
|
||||
# Clamp to [0, 12]
|
||||
half_step = half_step if half_step > Float32(0.0) else Float32(0.0)
|
||||
half_step = half_step if half_step < Float32(12.0) else Float32(12.0)
|
||||
# Convert to int and map to FP4 index
|
||||
hs_int = Int32(half_step)
|
||||
# LUT: {0:0, 2:1, 4:2, 6:3, 8:4, 10:5, 12:6, 14:7}
|
||||
# half_step is already quantized to even values 0,2,...,12
|
||||
fp4_idx = hs_int // Int32(2)
|
||||
fp4_idx = fp4_idx if fp4_idx < Int32(7) else Int32(6)
|
||||
|
||||
# Sign
|
||||
sign = Int32(1) if val < Float32(0.0) else Int32(0)
|
||||
nibble = fp4_idx | (sign << Int32(3))
|
||||
|
||||
# Pack: even elements in lower nibble, odd in upper
|
||||
if i % 2 == 0:
|
||||
packed = nibble
|
||||
else:
|
||||
packed = packed | (nibble << Int32(4))
|
||||
# Write the packed byte
|
||||
# For float4_e2m1fn_x2 output, we'd use a proper TMA store
|
||||
# For now, this is the quantization logic verification
|
||||
|
||||
# Store FP4 packed data (simplified — not using TMA yet)
|
||||
# This would need a proper GMEM write path
|
||||
|
||||
|
||||
def dequantize_nvfp4_simple(x_fp4, block_scale, global_scale, N):
|
||||
"""Simple dequantize for verification."""
|
||||
M = x_fp4.shape[0]
|
||||
block_size = SF_VEC_SIZE
|
||||
|
||||
raw = x_fp4.view(torch.uint8)
|
||||
even = raw & 0x0F
|
||||
odd = (raw >> 4) & 0x0F
|
||||
indices = torch.stack([even, odd], dim=-1).reshape(M, N)
|
||||
signs = (indices >= 8).float() * -2 + 1
|
||||
mag = indices % 8
|
||||
|
||||
idx_to_val = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 6.0],
|
||||
dtype=torch.float32, device='cuda')
|
||||
vals = idx_to_val[mag.long()]
|
||||
x_deq = signs * vals
|
||||
|
||||
block_scale_exp = block_scale.repeat_interleave(block_size, dim=-1).float()
|
||||
x_deq = x_deq * block_scale_exp * global_scale
|
||||
return x_deq.to(torch.bfloat16)
|
||||
|
||||
|
||||
def test_nvfp4_python():
|
||||
"""Verify Python NVFP4 quantization round-trip."""
|
||||
print("\n=== NVFP4 Python Round-Trip ===")
|
||||
torch.manual_seed(42)
|
||||
M, N = 128, 512
|
||||
x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda')
|
||||
x_fp4, sf = quantize_activation_nvfp4(x, 1.0)
|
||||
x_deq = dequantize_nvfp4_simple(x_fp4, sf, 1.0, N)
|
||||
cos = torch.nn.functional.cosine_similarity(x.flatten().float().unsqueeze(0), x_deq.flatten().float().unsqueeze(0)).item()
|
||||
print(f" Round-trip cos: {cos:.6f} ({'PASS' if cos >= 0.95 else 'FAIL'})")
|
||||
assert cos >= 0.95, f"Round-trip cosine too low: {cos}"
|
||||
|
||||
|
||||
def test_nvfp4_kernel_launch():
|
||||
"""Test the CuTeDSL quantization kernel (basic launch, not full quantization yet)."""
|
||||
print("\n=== NVFP4 Kernel Launch Test ===")
|
||||
print(" (Kernel implementation in progress — CuTeDSL quantization needs TMA + FP4 packing)")
|
||||
print(" Current status: quantization logic designed, need to add TMA store for FP4 output")
|
||||
|
||||
# For now, verify that the Python quantization matches our dequantize
|
||||
torch.manual_seed(42)
|
||||
M, N = 4, 64
|
||||
x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda')
|
||||
x_fp4, sf = quantize_activation_nvfp4(x, 1.0)
|
||||
x_deq = dequantize_nvfp4_simple(x_fp4, sf, 1.0, N)
|
||||
cos = torch.nn.functional.cosine_similarity(x.flatten().float().unsqueeze(0), x_deq.flatten().float().unsqueeze(0)).item()
|
||||
print(f" Small test cos: {cos:.6f}")
|
||||
|
||||
|
||||
def test():
|
||||
print("=== NVFP4-1.1: BF16→FP4 Quantization ===")
|
||||
test_nvfp4_python()
|
||||
test_nvfp4_kernel_launch()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
Reference in New Issue
Block a user