"""
NVFP4-1.1 Step 1: Standalone BF16→FP4 quantization kernel.

This kernel reads BF16 input from GMEM, quantizes to NVFP4 format
(per-16-element FP8 E4M3 scale + FP4 packed data), and writes to GMEM.
This is the quantization logic that will be fused into the SwiGLU epilogue
once verified correct.

NVFP4 quantization:
  1. For each 16-element microblock: compute amax
  2. block_scale = (amax / 6.0).to(float8_e4m3fn)          [FP8 E4M3]
  3. For each element: x_scaled = x / block_scale
  4. half_steps = round(|x_scaled| * 2).clamp(0, 12)       [13 quantization levels]
  5. Look up the FP4 (E2M1) code from half_steps
  6. Pack pairs of FP4 nibbles into bytes (float4_e2m1fn_x2)

Reference: dsv4/ops/quantize.py (quantize_activation_nvfp4)

Run:
  ~/.openclaw/workspace/fire_b200_test tests/unit/test_nvfp4_quantize_kernel.py
"""
import torch
import math
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
from cutlass import Float32, BFloat16, Float8E4M3FN, Int32, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
import cutlass.torch as ct
from dsv4.ops.quantize import quantize_activation_nvfp4, SF_VEC_SIZE

# Magnitudes of the 8 non-negative FP4 E2M1 codes (bit 3 is the sign bit).
# In half-steps these are 0, 1, 2, 3, 4, 6, 8, 12, i.e. the levels reachable
# by the clamp(0, 12) half-step quantization described above. E2M1 has no
# inf/NaN encodings.
_E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


class Fp4QuantizeKernel:
    """Standalone BF16→FP4 quantization kernel (CuTeDSL).

    Reads BF16 from GMEM, quantizes to NVFP4 (FP4 data + FP8 E4M3 scales),
    and writes both to GMEM.

    NOTE: ``__call__`` is currently a stub. The device-side implementation
    lands in the next step; until then the PyTorch reference
    ``quantize_activation_nvfp4`` is the ground truth.
    """

    def __init__(self, M, N, block_size=16, global_scale=1.0):
        """Record the problem shape and quantization parameters.

        Args:
            M: number of rows.
            N: number of columns; must be a positive multiple of ``block_size``.
            block_size: elements per microblock (16 for NVFP4).
            global_scale: tensor-wide scale stored for use by the kernel.

        Raises:
            ValueError: if ``block_size`` is not positive or does not divide ``N``.
        """
        # A partial trailing microblock would otherwise be silently dropped
        # by the floor division below.
        if block_size <= 0 or N % block_size != 0:
            raise ValueError(
                f"N ({N}) must be a positive multiple of block_size ({block_size})"
            )
        self.M = M
        self.N = N
        self.block_size = block_size  # 16 for NVFP4
        self.global_scale = global_scale
        self.n_blocks = N // block_size  # number of microblocks per row

    @cute.jit
    def __call__(self, x_bf16, x_fp4, x_sf, stream):
        """Quantize ``x_bf16`` into ``x_fp4`` / ``x_sf`` (not implemented yet).

        Args:
            x_bf16: (M, N) BF16 input.
            x_fp4:  (M, N // 2) float4_e2m1fn_x2 output (packed FP4).
            x_sf:   (M, N // 16) float8_e4m3fn output (FP8 E4M3 scales).
            stream: stream to launch on (presumably a ``cuda.CUstream`` —
                confirm once the launch is written).

        Planned thread mapping, worked for N=4096 and block_size=16:
          - n_blocks = 256 per row. Each block is 16 BF16 elements, producing
            8 packed FP4 bytes + 1 FP8 scale byte.
          - Total output: M x 2048 FP4 bytes and M x 256 scale bytes.
          - Grid (1, M, 1): one 128-thread CTA per row. Each thread handles
            2 microblocks (32 elements), so 128 x 2 = 256 blocks per row.
        """
        # TODO: implement the device kernel. Until then this is a no-op and
        # the output tensors are left untouched.
        pass


def test_nvfp4_quantize_correctness():
    """Verify that the PyTorch NVFP4 quantization reference round-trips.

    Checks the shapes/dtypes produced by ``quantize_activation_nvfp4``, then
    dequantizes the result and requires a cosine similarity of at least 0.95
    against the BF16 input.
    """
    print("\n=== NVFP4 Quantization Correctness Test ===")
    torch.manual_seed(42)
    M, N = 128, 512
    x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda')
    global_scale = 1.0

    x_fp4, block_scale = quantize_activation_nvfp4(x, global_scale)
    print(f" Input: ({M}, {N}) BF16")
    print(f" FP4 output: {x_fp4.shape}, dtype={x_fp4.dtype}")
    print(f" Scale output: {block_scale.shape}, dtype={block_scale.dtype}")
    print(f" SF_VEC_SIZE: {SF_VEC_SIZE}")

    assert x_fp4.dtype == torch.float4_e2m1fn_x2, f"Expected float4_e2m1fn_x2, got {x_fp4.dtype}"
    assert block_scale.dtype == torch.float8_e4m3fn, f"Expected float8_e4m3fn, got {block_scale.dtype}"
    assert x_fp4.shape == (M, N // 2), f"Expected ({M}, {N//2}), got {x_fp4.shape}"
    assert block_scale.shape == (M, N // SF_VEC_SIZE), f"Expected ({M}, {N//SF_VEC_SIZE}), got {block_scale.shape}"
    print(" ✅ Shapes and dtypes correct")

    # Round-trip: dequantize and compare against the original input
    # (FP4 has limited precision, so only approximate agreement is expected).
    x_deq = dequantize_nvfp4(x_fp4, block_scale, global_scale, N)
    x_ref = x.float()
    x_deq_f = x_deq.float()

    cos = torch.nn.functional.cosine_similarity(
        x_ref.flatten().unsqueeze(0), x_deq_f.flatten().unsqueeze(0)
    ).item()
    print(f" Dequantized cosine similarity: {cos:.6f}")
    print(f" FP4 round-trip: {'PASS' if cos >= 0.95 else 'FAIL (expected ~0.97-0.99 for FP4)'}")

    # Mean relative error over non-tiny inputs, computed in float32 so the
    # statistic itself is not degraded by BF16 arithmetic.
    mask = x_ref.abs() > 0.01  # Skip near-zero values
    rel_err = ((x_ref[mask] - x_deq_f[mask]).abs() / x_ref[mask].abs()).mean().item()
    print(f" Mean relative error: {rel_err:.4f}")

    # Fail the test (not just print FAIL) when the round-trip is poor.
    assert cos >= 0.95, f"FP4 round-trip cosine similarity {cos:.6f} < 0.95"


def dequantize_nvfp4(x_fp4, block_scale, global_scale, N):
    """Dequantize NVFP4 back to BF16 for verification.

    Inverse of ``quantize_activation_nvfp4`` up to FP4 rounding:
    ``value = e2m1(code) * block_scale * global_scale``.

    Args:
        x_fp4: (..., N // 2) float4_e2m1fn_x2 tensor. Each byte packs two
            E2M1 codes: the low nibble holds the even element and the high
            nibble the odd element.
        block_scale: (..., N // SF_VEC_SIZE) float8_e4m3fn per-microblock
            scales, expected in plain row-major (unswizzled) layout.
        global_scale: scalar multiplied in on top of the block scale.
            NOTE(review): confirm this matches how quantize_activation_nvfp4
            folds in global_scale; the test only exercises 1.0.
        N: number of unpacked elements per row.

    Returns:
        (..., N) BF16 tensor on the same device as ``x_fp4``.

    Raises:
        ValueError: if the packed or scale widths are inconsistent with ``N``.
    """
    raw = x_fp4.view(torch.uint8)  # (..., N // 2) bytes
    if raw.shape[-1] * 2 != N:
        raise ValueError(f"x_fp4 last dim {raw.shape[-1]} does not pack N={N} elements")
    if N % SF_VEC_SIZE != 0 or block_scale.shape[-1] != N // SF_VEC_SIZE:
        raise ValueError(
            f"block_scale last dim {block_scale.shape[-1]} does not match "
            f"N // SF_VEC_SIZE = {N // SF_VEC_SIZE}"
        )

    # Unpack nibbles and interleave so element 2k (low) precedes 2k+1 (high).
    low = raw & 0x0F
    high = (raw >> 4) & 0x0F
    # .long() is required: indexing with a uint8 tensor would be a boolean mask.
    codes = torch.stack((low, high), dim=-1).flatten(-2).long()  # (..., N)

    # E2M1 layout: bit 3 is the sign, bits 0-2 select the magnitude.
    lut = torch.tensor(_E2M1_MAGNITUDES, dtype=torch.float32, device=codes.device)
    magnitude = lut[codes & 0x7]
    values = torch.where(codes >= 8, -magnitude, magnitude)

    # Expand one scale per microblock to one per element. Converting to float
    # first avoids relying on float8 support in repeat_interleave.
    scale = block_scale.float().repeat_interleave(SF_VEC_SIZE, dim=-1)
    return (values * scale * global_scale).to(torch.bfloat16)


def test():
    """Script entry point: run the NVFP4 quantization checks."""
    print("=== NVFP4-1.1: BF16→FP4 Quantization Kernel Development ===")
    test_nvfp4_quantize_correctness()


if __name__ == '__main__':
    test()