NVFP4-1.1: add FP4 quantize round-trip test (step 1 of kernel fusion)

This commit is contained in:
2026-05-25 03:15:40 +00:00
parent eb46e4d15e
commit 6dac3bcaf0

View File

@@ -0,0 +1,190 @@
"""
NVFP4-1.1 Step 1: Standalone BF16→FP4 quantization kernel.
This kernel reads BF16 input from GMEM, quantizes to NVFP4 format
(per-16-element FP8 E4M3 scale + FP4 packed data), and writes to GMEM.
This is the quantization logic that will be fused into the SwiGLU epilogue
once verified correct.
NVFP4 quantization:
1. For each 16-element microblock: compute amax
2. block_scale = (amax / 6.0).to(float8_e4m3fn) [FP8 E4M3]
3. For each element: x_scaled = x / block_scale
4. half_steps = round(|x_scaled| * 2).clamp(0, 12) [13 quantization levels]
5. Lookup FP4 index from half_steps
6. Pack pairs of FP4 nibbles into bytes (float4_e2m1fn_x2)
Reference: dsv4/ops/quantize.py (quantize_activation_nvfp4)
Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_nvfp4_quantize_kernel.py
"""
import torch
import math
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
from cutlass import Float32, BFloat16, Float8E4M3FN, Int32, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
import cutlass.torch as ct
from dsv4.ops.quantize import quantize_activation_nvfp4, SF_VEC_SIZE
class Fp4QuantizeKernel:
"""Standalone BF16→FP4 quantization kernel (CuTeDSL).
Reads BF16 from GMEM, quantizes to NVFP4 (FP4 data + FP8 E4M3 scales),
writes both to GMEM.
"""
def __init__(self, M, N, block_size=16, global_scale=1.0):
self.M = M
self.N = N
self.block_size = block_size # 16 for NVFP4
self.global_scale = global_scale
self.n_blocks = N // block_size # number of 16-element blocks per row
@cute.jit
def __call__(self, x_bf16, x_fp4, x_sf, stream):
"""
x_bf16: (M, N) BF16 input
x_fp4: (M, N // 2) float4_e2m1fn_x2 output (packed FP4)
x_sf: (M, N // 16) float8_e4m3fn output (FP8 E4M3 scales)
"""
M = self.M; N = self.N; block_size = self.block_size
n_blocks = self.n_blocks
# Each CTA processes 1 row (for simplicity)
# Grid: (n_blocks_per_row, M, 1) — each CTA quantizes 1 block in 1 row
# Actually, let's do it differently: each CTA processes 1 row's worth of blocks
# Grid: (1, M, 1) — 1 CTA per row
# For now, let's use a simple 1-CTA approach (1 row at a time)
# and verify correctness before optimizing.
# Actually, this needs to be a proper kernel. Let me think about
# the thread mapping.
#
# For a (M, N) input with N=4096, block_size=16:
# - n_blocks = 256 per row
# - Each block: 16 BF16 elements → 8 FP4 bytes + 1 FP8 scale byte
# - Total FP4 output: M × 2048 bytes
# - Total SF output: M × 256 bytes
#
# Thread mapping: 128 threads per CTA, each thread handles 2 blocks
# (2 × 16 = 32 elements). With N=4096, n_blocks=256, need 128 threads × 2 = 256 blocks.
# For now, use the PyTorch reference as the ground truth and just verify.
# The kernel implementation will be added in the next step.
pass
def test_nvfp4_quantize_correctness():
"""Verify that our Python NVFP4 quantization matches the reference."""
print("\n=== NVFP4 Quantization Correctness Test ===")
torch.manual_seed(42)
M, N = 128, 512
x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda')
global_scale = 1.0
x_fp4, block_scale, _ = quantize_activation_nvfp4(x, global_scale)
print(f" Input: ({M}, {N}) BF16")
print(f" FP4 output: {x_fp4.shape}, dtype={x_fp4.dtype}")
print(f" Scale output: {block_scale.shape}, dtype={block_scale.dtype}")
print(f" SF_VEC_SIZE: {SF_VEC_SIZE}")
# Dequantize to verify round-trip
# FP4 → BF16: x_bf16 ≈ x_fp4 * block_scale * global_scale
# But we don't have a dequantize function yet.
# Let's just verify the shapes and dtypes are correct.
assert x_fp4.dtype == torch.float4_e2m1fn_x2, f"Expected float4_e2m1fn_x2, got {x_fp4.dtype}"
assert block_scale.dtype == torch.float8_e4m3fn, f"Expected float8_e4m3fn, got {block_scale.dtype}"
assert x_fp4.shape == (M, N // 2), f"Expected ({M}, {N//2}), got {x_fp4.shape}"
assert block_scale.shape == (M, N // SF_VEC_SIZE), f"Expected ({M}, {N//SF_VEC_SIZE}), got {block_scale.shape}"
print(" ✅ Shapes and dtypes correct")
# Verify that the quantized values round-trip correctly
# by dequantizing and comparing
# We need to manually dequantize since there's no built-in function
x_deq = dequantize_nvfp4(x_fp4, block_scale, global_scale, N)
# Compare with original (FP4 has limited precision)
cos = torch.nn.functional.cosine_similarity(
x.flatten().unsqueeze(0).float(),
x_deq.flatten().unsqueeze(0).float()
).item()
print(f" Dequantized cosine similarity: {cos:.6f}")
print(f" FP4 round-trip: {'PASS' if cos >= 0.95 else 'FAIL (expected ~0.97-0.99 for FP4)'}")
# Also verify max relative error
mask = x.abs() > 0.01 # Skip near-zero values
rel_err = ((x[mask] - x_deq[mask]).abs() / x[mask].abs()).mean().item()
print(f" Mean relative error: {rel_err:.4f}")
def dequantize_nvfp4(x_fp4, block_scale, global_scale, N):
"""Dequantize NVFP4 back to BF16 for verification.
This is the inverse of quantize_activation_nvfp4.
"""
M = x_fp4.shape[0]
block_size = SF_VEC_SIZE
# Unpack FP4 nibbles to indices
raw = x_fp4.view(torch.uint8) # (M, N//2) bytes
even_nibbles = raw & 0x0F # Lower nibble
odd_nibbles = (raw >> 4) & 0x0F # Upper nibble
# Interleave even and odd
indices = torch.stack([even_nibbles, odd_nibbles], dim=-1).reshape(M, N)
# Extract sign and magnitude
signs = (indices >= 8).float() * -2 + 1 # +1 for 0-7, -1 for 8-15
mag_indices = indices % 8 # Magnitude index 0-7
# FP4 E2M1 values: 0, 2, 3, 4, 6, 8, 12, inf (but in our quantization,
# the step LUT maps half_steps to specific values)
# For simplicity, use the dequantization from the step LUT
# The E2M1 format: (-1)^sign × 2^exp × (1 + mantissa/2)
# exp = (mag >> 1), mantissa = mag & 1
# value = (-1)^sign × 2^exp × (1 + (mag & 1) * 0.5)
# But this doesn't match the NVFP4 quantization exactly.
# Let me use a different approach: compute from the quantize step LUT.
# Actually, let me just do: x_deq = indices_to_values * block_scale * global_scale
# But I don't have the step LUT in the kernel. Let me compute it directly.
# The step LUT in quantize.py:
# step_to_idx = {0:0, 2:1, 4:2, 6:3, 8:4, 10:5, 12:6, 14:7}
# inverse: idx_to_half_step = {0:0, 1:2, 2:4, 3:6, 4:8, 5:10, 6:12, 7:14}
# The dequantized value = sign * half_step / 2.0 = sign * (idx_to_half_step[mag] / 2.0)
idx_to_half_step = torch.tensor([0, 2, 4, 6, 8, 10, 12, 14],
dtype=torch.float32, device='cuda')
half_steps = idx_to_half_step[mag_indices.long()] # (M, N)
# Dequantized: value = sign * half_step / 2.0
x_deq_fp32 = signs * half_steps / 2.0 # (M, N)
# Apply block scale and global scale
# block_scale shape: (M, N//16) → expand to (M, N)
block_scale_expanded = block_scale.repeat_interleave(block_size, dim=-1).float()
x_deq_fp32 = x_deq_fp32 * block_scale_expanded * global_scale
return x_deq_fp32.to(torch.bfloat16)
def test():
print("=== NVFP4-1.1: BF16→FP4 Quantization Kernel Development ===")
test_nvfp4_quantize_correctness()
if __name__ == '__main__':
test()