Files
nvfp4-megamoe-kernel/tests/unit/test_nvfp4_quantize_kernel.py

191 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
NVFP4-1.1 Step 1: Standalone BF16→FP4 quantization kernel.

This kernel reads BF16 input from GMEM, quantizes it to NVFP4 format
(per-16-element FP8 E4M3 scale + FP4 packed data), and writes the result
to GMEM. This is the quantization logic that will be fused into the SwiGLU
epilogue once it has been verified correct.

NVFP4 quantization:
  1. For each 16-element microblock: compute amax
  2. block_scale = (amax / 6.0).to(float8_e4m3fn) [FP8 E4M3]
  3. For each element: x_scaled = x / block_scale
  4. half_steps = round(|x_scaled| * 2).clamp(0, 12) [13 quantization levels]
  5. Look up the FP4 index from half_steps
  6. Pack pairs of FP4 nibbles into bytes (float4_e2m1fn_x2)

Reference: dsv4/ops/quantize.py (quantize_activation_nvfp4)
Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_nvfp4_quantize_kernel.py
"""
import torch
import math
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
from cutlass import Float32, BFloat16, Float8E4M3FN, Int32, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
import cutlass.torch as ct
from dsv4.ops.quantize import quantize_activation_nvfp4, SF_VEC_SIZE
class Fp4QuantizeKernel:
    """Skeleton of the standalone BF16 -> FP4 quantization kernel (CuTeDSL).

    The finished kernel is meant to read BF16 from GMEM, quantize to NVFP4
    (packed FP4 data plus one FP8 E4M3 scale per microblock) and write both
    outputs back to GMEM. The device code is not written yet: ``__call__`` is
    a no-op, and the PyTorch reference remains the ground truth.
    """

    def __init__(self, M, N, block_size=16, global_scale=1.0):
        """Store the problem shape and quantization parameters.

        Args:
            M: Number of input rows.
            N: Number of input columns; must be a multiple of ``block_size``.
            block_size: Elements per scale microblock (16 for NVFP4).
            global_scale: Global scale value, stored as given.

        Raises:
            ValueError: If ``block_size`` is not positive, or if it does not
                divide ``N`` evenly.
        """
        if block_size <= 0:
            raise ValueError(f"block_size must be positive, got {block_size}")
        if N % block_size != 0:
            # Floor division below would otherwise silently drop the trailing
            # partial block and leave n_blocks inconsistent with N.
            raise ValueError(
                f"N ({N}) must be a multiple of block_size ({block_size})"
            )
        self.M = M
        self.N = N
        self.block_size = block_size  # 16 for NVFP4
        self.global_scale = global_scale
        self.n_blocks = N // block_size  # number of microblocks per row

    @cute.jit
    def __call__(self, x_bf16, x_fp4, x_sf, stream):
        """Placeholder launch entry point; currently performs no work.

        Intended operands (from the design notes, not yet enforced):
            x_bf16: (M, N) BF16 input.
            x_fp4: (M, N // 2) float4_e2m1fn_x2 output (packed FP4).
            x_sf: (M, N // 16) float8_e4m3fn output (FP8 E4M3 scales).
            stream: Launch stream (unused until the kernel exists).
        """
        # TODO: implement the device kernel.
        #
        # Planned mapping: one CTA per row, i.e. grid (1, M, 1).
        # Worked example for N = 4096, block_size = 16:
        #   - n_blocks = 256 per row
        #   - each block: 16 BF16 elements -> 8 FP4 bytes + 1 FP8 scale byte
        #   - total FP4 output: M x 2048 bytes; total SF output: M x 256 bytes
        #   - 128 threads per CTA, 2 blocks (32 elements) per thread
        #     -> 128 x 2 = 256 blocks, covering the row.
        #
        # Until then the PyTorch reference is the ground truth.
        pass
def test_nvfp4_quantize_correctness():
    """Check the reference NVFP4 quantizer's output layout and round-trip accuracy.

    Quantizes a seeded random (128, 512) BF16 tensor with
    ``quantize_activation_nvfp4``, asserts the dtypes and shapes of the packed
    FP4 data and the block scales, then dequantizes with ``dequantize_nvfp4``
    and asserts that the result is close to the input.

    Raises:
        AssertionError: On a dtype or shape mismatch, or when the round-trip
            cosine similarity falls below 0.95.
    """
    print("\n=== NVFP4 Quantization Correctness Test ===")
    torch.manual_seed(42)
    M, N = 128, 512
    min_cosine = 0.95  # FP4 is coarse; ~0.97-0.99 is the expected range
    x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda')
    global_scale = 1.0
    x_fp4, block_scale = quantize_activation_nvfp4(x, global_scale)
    print(f" Input: ({M}, {N}) BF16")
    print(f" FP4 output: {x_fp4.shape}, dtype={x_fp4.dtype}")
    print(f" Scale output: {block_scale.shape}, dtype={block_scale.dtype}")
    print(f" SF_VEC_SIZE: {SF_VEC_SIZE}")

    # Layout checks: two FP4 values per byte, one scale per microblock.
    assert x_fp4.dtype == torch.float4_e2m1fn_x2, f"Expected float4_e2m1fn_x2, got {x_fp4.dtype}"
    assert block_scale.dtype == torch.float8_e4m3fn, f"Expected float8_e4m3fn, got {block_scale.dtype}"
    assert x_fp4.shape == (M, N // 2), f"Expected ({M}, {N//2}), got {x_fp4.shape}"
    assert block_scale.shape == (M, N // SF_VEC_SIZE), f"Expected ({M}, {N//SF_VEC_SIZE}), got {block_scale.shape}"
    print(" ✅ Shapes and dtypes correct")

    # Round-trip check: dequantize and compare against the original input.
    x_deq = dequantize_nvfp4(x_fp4, block_scale, global_scale, N)
    x_ref = x.float()
    x_rt = x_deq.float()
    cos = torch.nn.functional.cosine_similarity(
        x_ref.flatten().unsqueeze(0),
        x_rt.flatten().unsqueeze(0),
    ).item()
    print(f" Dequantized cosine similarity: {cos:.6f}")
    print(f" FP4 round-trip: {'PASS' if cos >= min_cosine else 'FAIL (expected ~0.97-0.99 for FP4)'}")

    # Mean relative error, reported for information only. Computed in FP32 so
    # the metric itself is not degraded by BF16 rounding; near-zero inputs are
    # skipped because their relative error is not meaningful.
    mask = x_ref.abs() > 0.01
    rel_err = ((x_ref[mask] - x_rt[mask]).abs() / x_ref[mask].abs()).mean().item()
    print(f" Mean relative error: {rel_err:.4f}")

    # Fail the test (not just print) when the round-trip is out of tolerance.
    assert cos >= min_cosine, (
        f"FP4 round-trip cosine similarity {cos:.6f} is below {min_cosine}"
    )
def dequantize_nvfp4(x_fp4, block_scale, global_scale, N):
    """Dequantize NVFP4 (packed FP4 E2M1 codes + per-block scales) to BF16.

    Each byte of ``x_fp4`` is treated as two 4-bit codes: the low nibble is
    the even-indexed element and the high nibble the odd-indexed one. Bit 3 of
    a code is the sign and bits 0-2 select the E2M1 magnitude.

    Args:
        x_fp4: (M, N // 2) packed tensor, reinterpreted as uint8.
        block_scale: (M, N // SF_VEC_SIZE) scales, one per microblock.
        global_scale: Scalar multiplied into every dequantized value.
        N: Number of logical elements per row.

    Returns:
        (M, N) BF16 tensor on the same device as ``x_fp4``.
    """
    M = x_fp4.shape[0]

    # Unpack the two 4-bit codes of every byte and restore element order.
    raw = x_fp4.view(torch.uint8)  # (M, N // 2)
    low_nibbles = raw & 0x0F
    high_nibbles = (raw >> 4) & 0x0F
    codes = torch.stack([low_nibbles, high_nibbles], dim=-1).reshape(M, N).long()

    # FP4 E2M1 magnitudes for codes 0..7. In half-steps these are
    # 0, 1, 2, 3, 4, 6, 8, 12, so the largest magnitude is 6.0 -- consistent
    # with block_scale = amax / 6.0 on the quantization side.
    e2m1_magnitudes = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
        dtype=torch.float32,
        device=raw.device,  # follow the input instead of assuming CUDA
    )
    magnitudes = e2m1_magnitudes[codes & 0x7]
    signs = 1.0 - 2.0 * (codes >= 8).float()  # +1 for codes 0-7, -1 for 8-15
    values = signs * magnitudes  # (M, N)

    # Upcast the scales before expanding them, so the expansion runs in FP32
    # rather than on the FP8 storage dtype; then broadcast one scale per block.
    scales = block_scale.float().repeat_interleave(SF_VEC_SIZE, dim=-1)  # (M, N)
    return (values * scales * global_scale).to(torch.bfloat16)
def test():
    """Print the suite banner, then run the quantization correctness test."""
    banner = "=== NVFP4-1.1: BF16→FP4 Quantization Kernel Development ==="
    print(banner)
    test_nvfp4_quantize_correctness()
# Run the suite only when this file is executed as a script, not on import.
if __name__ == "__main__":
    test()