"""NVFP4-1.1 Step 3: Test GPU quantize fused with SwiGLU GEMM.

Runs:          SwiGLU GEMM (BF16 output) → GPU FP4 quantize kernel
Compares with: SwiGLU GEMM (BF16 output) → PyTorch FP4 quantize

Run:
    ~/.openclaw_workspace/fire_b200_test tests/unit/test_nvfp4_quant_kernel.py
"""

import torch
import math
import sys
import os

import cutlass
import cutlass.cute as cute
from cutlass import Float32, BFloat16, Float8E4M3FN, Int32, Uint8, Uint16
import cuda.bindings.driver as cuda
import cutlass.torch as ct

from dsv4.ops.quantize import quantize_activation_nvfp4, SF_VEC_SIZE

# Magnitudes of the eight E2M1 codes, indexed by the low three bits of a
# nibble (bit 3 is the sign).
_E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)

# Scale-block width the tests instantiate the kernel with.
_KERNEL_BLOCK_SIZE = 16

# Acceptance thresholds for the quantize -> dequantize round trip.
_MIN_COSINE = 0.95
# Cosine similarity is blind to a uniform gain error, so the round trip is
# also bounded in relative L2 error.
# NOTE(review): 0.25 is an estimate for roughly Gaussian inputs (expected
# error is around 0.1); tune after a run on hardware.
_MAX_REL_ERROR = 0.25


# ── Quantize kernel (from Step 2) ──

def _fmax(a, b):
    """Return max(a, b) without branching, as (a + b + |a - b|) / 2.

    The result is computed arithmetically, so it can differ from the exact
    maximum by floating-point rounding.
    """
    return (a + b + cute.math.absf(a - b)) / Float32(2.0)


def _fmin(a, b):
    """Return min(a, b) without branching, as (a + b - |a - b|) / 2."""
    return (a + b - cute.math.absf(a - b)) / Float32(2.0)


def _clamp(x, lo, hi):
    """Clamp ``x`` into ``[lo, hi]`` using the branch-free min/max helpers."""
    return _fmin(_fmax(x, lo), hi)


class Nvfp4QuantizeKernel:
    """BF16 -> packed FP4 (E2M1) quantizer with one FP8 scale per block.

    Launch shape: one CUDA block of ``THREADS_PER_ROW`` threads per input row.
    Each thread handles ``(N // block_size) // THREADS_PER_ROW`` consecutive
    scale blocks of its row. Because of the integer division, trailing scale
    blocks are NOT processed unless ``N`` is a multiple of
    ``block_size * THREADS_PER_ROW``; callers must enforce that.

    Output addressing is computed as dense row-major offsets
    (``row * (N // 2)`` bytes for FP4, ``row * (N // block_size)`` for
    scales), so both output tensors are expected to be contiguous.
    """

    # Threads per CUDA block; also the divisor used to split scale blocks
    # across threads inside the kernel.
    THREADS_PER_ROW = 32

    def __init__(self, block_size=16):
        """Store the number of elements that share one scale factor."""
        self.block_size = block_size

    @cute.jit
    def __call__(self, x_bf16, x_fp4, x_sf, M, N, stream):
        """Launch the quantize kernel over an ``M`` x ``N`` input.

        Args:
            x_bf16: Input tensor; only its iterator and first two strides
                are used.
            x_fp4: Output tensor receiving two 4-bit codes per byte.
            x_sf: Output tensor receiving one FP8 (E4M3) scale per block.
            M: Number of rows (one CUDA block each).
            N: Number of columns.
            stream: CUDA stream to launch on.
        """
        x_bf16_ptr = x_bf16.iterator
        x_fp4_ptr = x_fp4.iterator
        x_sf_ptr = x_sf.iterator
        stride0 = x_bf16.stride[0]
        stride1 = x_bf16.stride[1]
        self._kernel(
            x_bf16_ptr, x_fp4_ptr, x_sf_ptr, M, N, stride0, stride1
        ).launch(
            grid=(M, 1, 1),
            block=[self.THREADS_PER_ROW, 1, 1],
            stream=stream,
        )

    @cute.kernel
    def _kernel(self, x_bf16_ptr, x_fp4_ptr, x_sf_ptr, M, N, stride0, stride1):
        """Quantize the scale blocks of one row that belong to this thread."""
        tidx, _, _ = cute.arch.thread_idx()
        bidx, _, _ = cute.arch.block_idx()

        row = bidx
        bs = Int32(self.block_size)
        n_blocks = N // bs
        threads = Int32(self.THREADS_PER_ROW)
        # Integer division: any remainder blocks are left untouched.
        blocks_per_thread = n_blocks // threads

        for b in cutlass.range(blocks_per_thread):
            block_idx = tidx * blocks_per_thread + b
            col_start = block_idx * bs

            # Pass 1: absolute maximum of the block.
            amax = Float32(0.0)
            for i in cutlass.range(self.block_size):
                offset = row * stride0 + (col_start + Int32(i)) * stride1
                raw = cute.arch.load(x_bf16_ptr + offset, Uint16)
                val = raw.bitcast(BFloat16).to(Float32)
                amax = _fmax(amax, cute.math.absf(val))

            # 6.0 is the largest E2M1 magnitude, so amax maps onto it.
            # NOTE(review): an all-zero block gives scale == 0 and a 0/0
            # division below; confirm the resulting code is 0 on the target
            # hardware.
            scale = amax / Float32(6.0)

            sf_offset = row * n_blocks + block_idx
            sf_val = scale.to(Float8E4M3FN).bitcast(Uint8)
            cute.arch.store(x_sf_ptr + sf_offset, sf_val)

            # Pass 2: encode two elements per output byte.
            # NOTE(review): elements are divided by the FP32 scale, while the
            # stored (and later dequantized-with) scale is FP8-rounded;
            # confirm whether the reference divides by the rounded scale.
            for i in cutlass.range(0, self.block_size, 2):
                off0 = row * stride0 + (col_start + Int32(i)) * stride1
                off1 = row * stride0 + (col_start + Int32(i + 1)) * stride1
                raw0 = cute.arch.load(x_bf16_ptr + off0, Uint16)
                raw1 = cute.arch.load(x_bf16_ptr + off1, Uint16)
                val0 = raw0.bitcast(BFloat16).to(Float32)
                val1 = raw1.bitcast(BFloat16).to(Float32)

                a0 = cute.math.absf(val0 / scale)
                a1 = cute.math.absf(val1 / scale)

                # Round-to-nearest onto {0, .5, 1, 1.5, 2, 3, 4, 6}: the code
                # is the number of midpoints the magnitude has passed.
                # Ties go to the even code, hence the alternating > / >=.
                # Magnitudes above 6 saturate at code 7.
                # (The previous floor(2a + 0.5) // 2 formula produced roughly
                # floor(a), which decodes to about half the intended value
                # and could never emit code 7.)
                idx0 = Int32(0)
                idx0 = idx0 + (Int32(1) if a0 > Float32(0.25) else Int32(0))
                idx0 = idx0 + (Int32(1) if a0 >= Float32(0.75) else Int32(0))
                idx0 = idx0 + (Int32(1) if a0 > Float32(1.25) else Int32(0))
                idx0 = idx0 + (Int32(1) if a0 >= Float32(1.75) else Int32(0))
                idx0 = idx0 + (Int32(1) if a0 > Float32(2.5) else Int32(0))
                idx0 = idx0 + (Int32(1) if a0 >= Float32(3.5) else Int32(0))
                idx0 = idx0 + (Int32(1) if a0 > Float32(5.0) else Int32(0))

                idx1 = Int32(0)
                idx1 = idx1 + (Int32(1) if a1 > Float32(0.25) else Int32(0))
                idx1 = idx1 + (Int32(1) if a1 >= Float32(0.75) else Int32(0))
                idx1 = idx1 + (Int32(1) if a1 > Float32(1.25) else Int32(0))
                idx1 = idx1 + (Int32(1) if a1 >= Float32(1.75) else Int32(0))
                idx1 = idx1 + (Int32(1) if a1 > Float32(2.5) else Int32(0))
                idx1 = idx1 + (Int32(1) if a1 >= Float32(3.5) else Int32(0))
                idx1 = idx1 + (Int32(1) if a1 > Float32(5.0) else Int32(0))

                sign0 = Int32(1) if val0 < Float32(0.0) else Int32(0)
                sign1 = Int32(1) if val1 < Float32(0.0) else Int32(0)
                nibble0 = idx0 | (sign0 << Int32(3))
                nibble1 = idx1 | (sign1 << Int32(3))
                # Even column goes in the low nibble, odd column in the high.
                packed = nibble0 | (nibble1 << Int32(4))

                fp4_offset = (
                    row * (N // Int32(2))
                    + block_idx * Int32(self.block_size // 2)
                    + Int32(i // 2)
                )
                cute.arch.store(x_fp4_ptr + fp4_offset, packed.to(Uint8))


def dequantize_nvfp4(x_fp4, block_scale, global_scale, N):
    """Dequantize NVFP4 back to BF16 for verification.

    Args:
        x_fp4: Packed codes, two per byte (low nibble = even column); viewed
            as uint8 and expected to hold ``N // 2`` bytes per row.
        block_scale: One scale per ``SF_VEC_SIZE`` columns of each row.
        global_scale: Scalar multiplier applied on top of the block scales.
        N: Number of logical columns after unpacking.

    Returns:
        A ``(M, N)`` bfloat16 tensor on the same device as ``x_fp4``.
    """
    M = x_fp4.shape[0]
    block_size = SF_VEC_SIZE

    raw = x_fp4.view(torch.uint8)
    even = raw & 0x0F
    odd = (raw >> 4) & 0x0F
    indices = torch.stack([even, odd], dim=-1).reshape(M, N)

    # Bit 3 is the sign: map {0, 1} -> {+1, -1}.
    signs = (indices >= 8).float() * -2 + 1
    mag = indices % 8

    # Build the lookup table on the input's device instead of assuming CUDA.
    idx_to_val = torch.tensor(
        _E2M1_MAGNITUDES, dtype=torch.float32, device=x_fp4.device
    )
    x_deq = signs * idx_to_val[mag.long()]

    # Convert to float32 before expanding so the expansion does not depend
    # on the scale dtype supporting repeat_interleave.
    block_scale_exp = block_scale.float().repeat_interleave(block_size, dim=-1)
    x_deq = x_deq * block_scale_exp * global_scale
    return x_deq.to(torch.bfloat16)


def _quantize_on_gpu(x):
    """Run ``Nvfp4QuantizeKernel`` on a 2-D BF16 tensor.

    Returns:
        ``(x_fp4, sf)``: the packed uint8 codes of shape ``(M, N // 2)`` and
        the float8_e4m3fn scales of shape ``(M, N // 16)``.

    Raises:
        ValueError: if the width is not a multiple of
            ``block_size * THREADS_PER_ROW``; the kernel would otherwise
            skip the trailing blocks and leave zeros in the outputs.
    """
    M, N = x.shape
    cols_per_pass = _KERNEL_BLOCK_SIZE * Nvfp4QuantizeKernel.THREADS_PER_ROW
    if N % cols_per_pass != 0:
        raise ValueError(
            f"N={N} must be a multiple of {cols_per_pass} for Nvfp4QuantizeKernel"
        )

    kernel = Nvfp4QuantizeKernel(block_size=_KERNEL_BLOCK_SIZE)
    x_fp4_out = torch.zeros(M, N // 2, dtype=torch.uint8, device=x.device)
    sf_out = torch.zeros(
        M, N // _KERNEL_BLOCK_SIZE, dtype=torch.float8_e4m3fn, device=x.device
    )

    stream = cuda.CUstream(0)
    kernel(
        ct.from_dlpack(x),
        ct.from_dlpack(x_fp4_out),
        ct.from_dlpack(sf_out),
        Int32(M),
        Int32(N),
        stream,
    )
    torch.cuda.synchronize()
    return x_fp4_out, sf_out


def _cosine(a, b):
    """Cosine similarity of two tensors, each flattened to a float32 vector."""
    return torch.nn.functional.cosine_similarity(
        a.flatten().float().unsqueeze(0),
        b.flatten().float().unsqueeze(0),
    ).item()


def _relative_error(reference, approx):
    """Relative L2 error ``||approx - reference|| / ||reference||``."""
    ref = reference.flatten().float()
    return ((approx.flatten().float() - ref).norm() / ref.norm()).item()


def test_nvfp4_standalone():
    """Step 2 regression: standalone quantize kernel."""
    print("\n=== NVFP4 Standalone Kernel Test ===")
    torch.manual_seed(42)
    M, N = 128, 512
    x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda')

    # Reference: PyTorch quantize (reported for comparison).
    x_fp4_ref, sf_ref = quantize_activation_nvfp4(x, 1.0)
    x_deq_ref = dequantize_nvfp4(x_fp4_ref, sf_ref, 1.0, N)

    x_fp4_out, sf_out = _quantize_on_gpu(x)
    x_deq_kernel = dequantize_nvfp4(x_fp4_out, sf_out, 1.0, N)

    cos_ref = _cosine(x, x_deq_ref)
    cos = _cosine(x, x_deq_kernel)
    rel = _relative_error(x, x_deq_kernel)
    print(f"  Python quantize cos: {cos_ref:.6f}")
    print(f"  Kernel cos: {cos:.6f} ({'PASS' if cos >= _MIN_COSINE else 'FAIL'})")
    print(f"  Kernel rel err: {rel:.6f}")
    assert cos >= _MIN_COSINE, f"Kernel cosine too low: {cos}"
    assert rel <= _MAX_REL_ERROR, f"Kernel relative error too high: {rel}"


def test_nvfp4_post_swiglu():
    """Step 3: GPU quantize after SwiGLU simulation.

    Simulates the SwiGLU output (gate * sigmoid(gate) * up) and then
    quantizes the BF16 output using the GPU kernel.
    """
    print("\n=== NVFP4 Post-SwiGLU Quantization Test ===")
    torch.manual_seed(42)
    M, N = 128, 512  # N must be even for SwiGLU (gate + up interleaved)

    # Simulate SwiGLU output: silu(gate) * up
    gate = torch.randn(M, N, dtype=torch.bfloat16, device='cuda')
    up = torch.randn(M, N, dtype=torch.bfloat16, device='cuda')
    silu_gate = gate * torch.nn.functional.sigmoid(gate.float()).bfloat16()
    swiglu_out = silu_gate * up  # BF16 SwiGLU output

    # Reference: PyTorch quantize
    x_fp4_ref, sf_ref = quantize_activation_nvfp4(swiglu_out, 1.0)
    x_deq_ref = dequantize_nvfp4(x_fp4_ref, sf_ref, 1.0, N)

    # GPU quantize
    x_fp4_out, sf_out = _quantize_on_gpu(swiglu_out)
    x_deq_kernel = dequantize_nvfp4(x_fp4_out, sf_out, 1.0, N)

    cos_kernel = _cosine(swiglu_out, x_deq_kernel)
    cos_ref = _cosine(swiglu_out, x_deq_ref)
    rel_kernel = _relative_error(swiglu_out, x_deq_kernel)

    print(f"  Python quantize cos: {cos_ref:.6f}")
    print(f"  GPU kernel cos:      {cos_kernel:.6f}")
    print(f"  Delta:               {abs(cos_kernel - cos_ref):.6f}")
    print(f"  GPU kernel rel err:  {rel_kernel:.6f}")

    assert cos_kernel >= _MIN_COSINE, f"Kernel cosine too low: {cos_kernel}"
    assert rel_kernel <= _MAX_REL_ERROR, (
        f"Kernel relative error too high: {rel_kernel}"
    )

    # Informational only: fraction of block scales identical to the reference.
    sf_match = (sf_out.float() == sf_ref.float()).float().mean().item()
    print(f"  FP8 scale match rate: {sf_match:.4f}")
    print(f"  ✅ Post-SwiGLU quantization PASS (cos={cos_kernel:.4f})")


def test_nvfp4_larger_shape():
    """Test with larger shapes (representative of real MoE)."""
    print("\n=== NVFP4 Larger Shape Test ===")
    torch.manual_seed(123)
    M, N = 512, 4096  # Typical MoE intermediate dim
    x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda')

    x_fp4_out, sf_out = _quantize_on_gpu(x)
    x_deq = dequantize_nvfp4(x_fp4_out, sf_out, 1.0, N)

    cos = _cosine(x, x_deq)
    rel = _relative_error(x, x_deq)
    print(f"  Kernel cos: {cos:.6f} ({'PASS' if cos >= _MIN_COSINE else 'FAIL'})")
    print(f"  Kernel rel err: {rel:.6f}")
    assert cos >= _MIN_COSINE, f"Large shape cosine too low: {cos}"
    assert rel <= _MAX_REL_ERROR, f"Large shape relative error too high: {rel}"


def test():
    """Run every NVFP4 quantization check in sequence."""
    print("=== NVFP4-1.1: BF16→FP4 Quantization (Step 2+3) ===")
    test_nvfp4_standalone()
    test_nvfp4_post_swiglu()
    test_nvfp4_larger_shape()
    print("\n=== ALL TESTS PASS ✅ ===")


if __name__ == '__main__':
    test()