diff --git a/tests/unit/test_nvfp4_quant_kernel.py b/tests/unit/test_nvfp4_quant_kernel.py new file mode 100644 index 00000000..4370b989 --- /dev/null +++ b/tests/unit/test_nvfp4_quant_kernel.py @@ -0,0 +1,176 @@ +""" +NVFP4-1.1: BF16→FP4 quantization kernel (CuTeDSL, Blackwell SM100). + +Reads BF16 from GMEM, quantizes to NVFP4, writes FP4 + FP8 scales to GMEM. +Uses TMA for efficient GMEM access. + +Grid: (num_rows, 1, 1) — 1 CTA per row. +Each CTA processes one row, with 128 threads each handling multiple 16-element blocks. + +Step 2 of the SwiGLU FP4 fusion. +Step 1: ✅ Python round-trip (cos 0.981) +Step 2: THIS — Standalone kernel +Step 3: Fuse into SwiGLU epilogue + +Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_nvfp4_quant_kernel.py +""" +import torch +import math +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +from cutlass import Float32, BFloat16, Float8E4M3FN, Int32, const_expr +import cuda.bindings.driver as cuda +import cutlass.torch as ct + +from dsv4.ops.quantize import quantize_activation_nvfp4, SF_VEC_SIZE + + +class Nvfp4QuantizeKernel: + def __init__(self, M, N, block_size=16): + self.M = M + self.N = N + self.block_size = block_size + + @cute.jit + def __call__(self, x_bf16, x_sf, stream): + """ + x_bf16: (M, N) BF16 input — also used as FP4 output (in-place, same memory) + x_sf: (M, N // 16) FP8 E4M3 scale factors + """ + M = self.M; N = self.N; bs = self.block_size + self._kernel(x_bf16, x_sf, Int32(M), Int32(N), Int32(bs)).launch( + grid=(M, 1, 1), block=[128, 1, 1], stream=stream + ) + + @cute.kernel + def _kernel(self, x_bf16, x_sf, M, N, block_size): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + row = bidx + + n_blocks = N // block_size # number of 16-element blocks per row + threads = Int32(128) + blocks_per_thread = n_blocks // threads # blocks handled by each thread + + # Each thread processes blocks_per_thread consecutive blocks + for b in range(blocks_per_thread): + block_idx = tidx * blocks_per_thread + b + col_start = block_idx * block_size + + # Step 1: Read 16 BF16 elements and compute amax + amax = Float32(0.0) + vals = [None] * 16 # Will store BF16 values for later quantization + for i in range(block_size): + # Direct GMEM read (not TMA — simpler for first implementation) + val = x_bf16[row, col_start + i] + abs_val = val * val # val^2 — we need |val| + # Actually, we need max(|val|). Let me use a simpler approach. + # CuTeDSL doesn't have abs() as a primitive. + # Use: abs_val = val if val > 0 else -val + abs_val = val if val > Float32(0.0) else Float32(0.0) - val + amax = amax if amax > abs_val else abs_val + + # Step 2: Compute FP8 E4M3 scale = (amax / 6.0) + # For now, store as FP32 (FP8 cast is complex in CuTeDSL) + scale = amax / Float32(6.0) if amax > Float32(0.0) else Float32(1.0) + x_sf[row, block_idx] = scale + + # Step 3: Quantize each BF16 element to FP4 and pack + packed = Int32(0) + for i in range(block_size): + val = x_bf16[row, col_start + i] + # Scale + scaled = val / scale + # Abs + abs_scaled = scaled if scaled > Float32(0.0) else Float32(0.0) - scaled + # Half-step: round(|scaled| * 2) + half_step_raw = abs_scaled * Float32(2.0) + # Round: floor(x + 0.5) + half_step = half_step_raw + Float32(0.5) + # Clamp to [0, 12] + half_step = half_step if half_step > Float32(0.0) else Float32(0.0) + half_step = half_step if half_step < Float32(12.0) else Float32(12.0) + # Convert to int and map to FP4 index + hs_int = Int32(half_step) + # LUT: {0:0, 2:1, 4:2, 6:3, 8:4, 10:5, 12:6, 14:7} + # half_step is already quantized to even values 0,2,...,12 + fp4_idx = hs_int // Int32(2) + fp4_idx = fp4_idx if fp4_idx < Int32(7) else Int32(6) + + # Sign + sign = Int32(1) if val < Float32(0.0) else Int32(0) + nibble = fp4_idx | (sign << Int32(3)) + + # Pack: even elements in lower nibble, odd in upper + if i % 2 == 0: + packed = nibble + else: + packed = packed | (nibble << Int32(4)) + # Write the packed byte + # For float4_e2m1fn_x2 output, we'd use a proper TMA store + # For now, this is the quantization logic verification + + # Store FP4 packed data (simplified — not using TMA yet) + # This would need a proper GMEM write path + + +def dequantize_nvfp4_simple(x_fp4, block_scale, global_scale, N): + """Simple dequantize for verification.""" + M = x_fp4.shape[0] + block_size = SF_VEC_SIZE + + raw = x_fp4.view(torch.uint8) + even = raw & 0x0F + odd = (raw >> 4) & 0x0F + indices = torch.stack([even, odd], dim=-1).reshape(M, N) + signs = (indices >= 8).float() * -2 + 1 + mag = indices % 8 + + idx_to_val = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 6.0], + dtype=torch.float32, device='cuda') + vals = idx_to_val[mag.long()] + x_deq = signs * vals + + block_scale_exp = block_scale.repeat_interleave(block_size, dim=-1).float() + x_deq = x_deq * block_scale_exp * global_scale + return x_deq.to(torch.bfloat16) + + +def test_nvfp4_python(): + """Verify Python NVFP4 quantization round-trip.""" + print("\n=== NVFP4 Python Round-Trip ===") + torch.manual_seed(42) + M, N = 128, 512 + x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda') + x_fp4, sf = quantize_activation_nvfp4(x, 1.0) + x_deq = dequantize_nvfp4_simple(x_fp4, sf, 1.0, N) + cos = torch.nn.functional.cosine_similarity(x.flatten().float().unsqueeze(0), x_deq.flatten().float().unsqueeze(0)).item() + print(f" Round-trip cos: {cos:.6f} ({'PASS' if cos >= 0.95 else 'FAIL'})") + assert cos >= 0.95, f"Round-trip cosine too low: {cos}" + + +def test_nvfp4_kernel_launch(): + """Test the CuTeDSL quantization kernel (basic launch, not full quantization yet).""" + print("\n=== NVFP4 Kernel Launch Test ===") + print(" (Kernel implementation in progress — CuTeDSL quantization needs TMA + FP4 packing)") + print(" Current status: quantization logic designed, need to add TMA store for FP4 output") + + # For now, verify that the Python quantization matches our dequantize + torch.manual_seed(42) + M, N = 4, 64 + x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda') + x_fp4, sf = quantize_activation_nvfp4(x, 1.0) + x_deq = dequantize_nvfp4_simple(x_fp4, sf, 1.0, N) + cos = torch.nn.functional.cosine_similarity(x.flatten().float().unsqueeze(0), x_deq.flatten().float().unsqueeze(0)).item() + print(f" Small test cos: {cos:.6f}") + + +def test(): + print("=== NVFP4-1.1: BF16→FP4 Quantization ===") + test_nvfp4_python() + test_nvfp4_kernel_launch() + + +if __name__ == '__main__': + test()