diff --git a/tests/unit/test_nvfp4_quantize_kernel.py b/tests/unit/test_nvfp4_quantize_kernel.py new file mode 100644 index 00000000..56342ead --- /dev/null +++ b/tests/unit/test_nvfp4_quantize_kernel.py @@ -0,0 +1,190 @@ +""" +NVFP4-1.1 Step 1: Standalone BF16→FP4 quantization kernel. + +This kernel reads BF16 input from GMEM, quantizes to NVFP4 format +(per-16-element FP8 E4M3 scale + FP4 packed data), and writes to GMEM. + +This is the quantization logic that will be fused into the SwiGLU epilogue +once verified correct. + +NVFP4 quantization: + 1. For each 16-element microblock: compute amax + 2. block_scale = (amax / 6.0).to(float8_e4m3fn) [FP8 E4M3] + 3. For each element: x_scaled = x / block_scale + 4. half_steps = round(|x_scaled| * 2).clamp(0, 12) [13 quantization levels] + 5. Lookup FP4 index from half_steps + 6. Pack pairs of FP4 nibbles into bytes (float4_e2m1fn_x2) + +Reference: dsv4/ops/quantize.py (quantize_activation_nvfp4) + +Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_nvfp4_quantize_kernel.py +""" +import torch +import math +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +from cutlass import Float32, BFloat16, Float8E4M3FN, Int32, const_expr +from cutlass.utils import LayoutEnum +import cuda.bindings.driver as cuda +import cutlass.torch as ct + +from dsv4.ops.quantize import quantize_activation_nvfp4, SF_VEC_SIZE + + +class Fp4QuantizeKernel: + """Standalone BF16→FP4 quantization kernel (CuTeDSL). + + Reads BF16 from GMEM, quantizes to NVFP4 (FP4 data + FP8 E4M3 scales), + writes both to GMEM. + """ + def __init__(self, M, N, block_size=16, global_scale=1.0): + self.M = M + self.N = N + self.block_size = block_size # 16 for NVFP4 + self.global_scale = global_scale + self.n_blocks = N // block_size # number of 16-element blocks per row + + @cute.jit + def __call__(self, x_bf16, x_fp4, x_sf, stream): + """ + x_bf16: (M, N) BF16 input + x_fp4: (M, N // 2) float4_e2m1fn_x2 output (packed FP4) + x_sf: (M, N // 16) float8_e4m3fn output (FP8 E4M3 scales) + """ + M = self.M; N = self.N; block_size = self.block_size + n_blocks = self.n_blocks + + # Each CTA processes 1 row (for simplicity) + # Grid: (n_blocks_per_row, M, 1) — each CTA quantizes 1 block in 1 row + # Actually, let's do it differently: each CTA processes 1 row's worth of blocks + # Grid: (1, M, 1) — 1 CTA per row + + # For now, let's use a simple 1-CTA approach (1 row at a time) + # and verify correctness before optimizing. + + # Actually, this needs to be a proper kernel. Let me think about + # the thread mapping. + # + # For a (M, N) input with N=4096, block_size=16: + # - n_blocks = 256 per row + # - Each block: 16 BF16 elements → 8 FP4 bytes + 1 FP8 scale byte + # - Total FP4 output: M × 2048 bytes + # - Total SF output: M × 256 bytes + # + # Thread mapping: 128 threads per CTA, each thread handles 2 blocks + # (2 × 16 = 32 elements). With N=4096, n_blocks=256, need 128 threads × 2 = 256 blocks. + + # For now, use the PyTorch reference as the ground truth and just verify. + # The kernel implementation will be added in the next step. + pass + + +def test_nvfp4_quantize_correctness(): + """Verify that our Python NVFP4 quantization matches the reference.""" + print("\n=== NVFP4 Quantization Correctness Test ===") + + torch.manual_seed(42) + M, N = 128, 512 + + x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda') + global_scale = 1.0 + + x_fp4, block_scale, _ = quantize_activation_nvfp4(x, global_scale) + + print(f" Input: ({M}, {N}) BF16") + print(f" FP4 output: {x_fp4.shape}, dtype={x_fp4.dtype}") + print(f" Scale output: {block_scale.shape}, dtype={block_scale.dtype}") + print(f" SF_VEC_SIZE: {SF_VEC_SIZE}") + + # Dequantize to verify round-trip + # FP4 → BF16: x_bf16 ≈ x_fp4 * block_scale * global_scale + # But we don't have a dequantize function yet. + # Let's just verify the shapes and dtypes are correct. + + assert x_fp4.dtype == torch.float4_e2m1fn_x2, f"Expected float4_e2m1fn_x2, got {x_fp4.dtype}" + assert block_scale.dtype == torch.float8_e4m3fn, f"Expected float8_e4m3fn, got {block_scale.dtype}" + assert x_fp4.shape == (M, N // 2), f"Expected ({M}, {N//2}), got {x_fp4.shape}" + assert block_scale.shape == (M, N // SF_VEC_SIZE), f"Expected ({M}, {N//SF_VEC_SIZE}), got {block_scale.shape}" + + print(" ✅ Shapes and dtypes correct") + + # Verify that the quantized values round-trip correctly + # by dequantizing and comparing + # We need to manually dequantize since there's no built-in function + x_deq = dequantize_nvfp4(x_fp4, block_scale, global_scale, N) + + # Compare with original (FP4 has limited precision) + cos = torch.nn.functional.cosine_similarity( + x.flatten().unsqueeze(0).float(), + x_deq.flatten().unsqueeze(0).float() + ).item() + + print(f" Dequantized cosine similarity: {cos:.6f}") + print(f" FP4 round-trip: {'PASS' if cos >= 0.95 else 'FAIL (expected ~0.97-0.99 for FP4)'}") + + # Also verify max relative error + mask = x.abs() > 0.01 # Skip near-zero values + rel_err = ((x[mask] - x_deq[mask]).abs() / x[mask].abs()).mean().item() + print(f" Mean relative error: {rel_err:.4f}") + + +def dequantize_nvfp4(x_fp4, block_scale, global_scale, N): + """Dequantize NVFP4 back to BF16 for verification. + + This is the inverse of quantize_activation_nvfp4. + """ + M = x_fp4.shape[0] + block_size = SF_VEC_SIZE + + # Unpack FP4 nibbles to indices + raw = x_fp4.view(torch.uint8) # (M, N//2) bytes + even_nibbles = raw & 0x0F # Lower nibble + odd_nibbles = (raw >> 4) & 0x0F # Upper nibble + + # Interleave even and odd + indices = torch.stack([even_nibbles, odd_nibbles], dim=-1).reshape(M, N) + + # Extract sign and magnitude + signs = (indices >= 8).float() * -2 + 1 # +1 for 0-7, -1 for 8-15 + mag_indices = indices % 8 # Magnitude index 0-7 + + # FP4 E2M1 values: 0, 2, 3, 4, 6, 8, 12, inf (but in our quantization, + # the step LUT maps half_steps to specific values) + # For simplicity, use the dequantization from the step LUT + # The E2M1 format: (-1)^sign × 2^exp × (1 + mantissa/2) + # exp = (mag >> 1), mantissa = mag & 1 + # value = (-1)^sign × 2^exp × (1 + (mag & 1) * 0.5) + # But this doesn't match the NVFP4 quantization exactly. + # Let me use a different approach: compute from the quantize step LUT. + + # Actually, let me just do: x_deq = indices_to_values * block_scale * global_scale + # But I don't have the step LUT in the kernel. Let me compute it directly. + + # The step LUT in quantize.py: + # step_to_idx = {0:0, 2:1, 4:2, 6:3, 8:4, 10:5, 12:6, 14:7} + # inverse: idx_to_half_step = {0:0, 1:2, 2:4, 3:6, 4:8, 5:10, 6:12, 7:14} + # The dequantized value = sign * half_step / 2.0 = sign * (idx_to_half_step[mag] / 2.0) + + idx_to_half_step = torch.tensor([0, 2, 4, 6, 8, 10, 12, 14], + dtype=torch.float32, device='cuda') + half_steps = idx_to_half_step[mag_indices.long()] # (M, N) + + # Dequantized: value = sign * half_step / 2.0 + x_deq_fp32 = signs * half_steps / 2.0 # (M, N) + + # Apply block scale and global scale + # block_scale shape: (M, N//16) → expand to (M, N) + block_scale_expanded = block_scale.repeat_interleave(block_size, dim=-1).float() + x_deq_fp32 = x_deq_fp32 * block_scale_expanded * global_scale + + return x_deq_fp32.to(torch.bfloat16) + + +def test(): + print("=== NVFP4-1.1: BF16→FP4 Quantization Kernel Development ===") + test_nvfp4_quantize_correctness() + + +if __name__ == '__main__': + test()