"""
NVFP4-1.1 Step 1: Standalone BF16→FP4 quantization kernel.

This kernel reads BF16 input from GMEM, quantizes it to NVFP4 format
(one FP8 E4M3 scale per 16-element microblock plus packed FP4 data), and
writes both results back to GMEM. The device kernel itself is not written
yet; for now this file only checks the PyTorch reference implementation.

This is the quantization logic that will be fused into the SwiGLU epilogue
once it has been verified correct.

NVFP4 quantization:
1. For each 16-element microblock: compute amax
2. block_scale = (amax / 6.0).to(float8_e4m3fn) [FP8 E4M3]
3. For each element: x_scaled = x / block_scale
4. half_steps = round(|x_scaled| * 2).clamp(0, 12) [13 quantization levels]
5. Look up the FP4 (E2M1) code from half_steps
   (E2M1 magnitudes: 0, 0.5, 1, 1.5, 2, 3, 4, 6)
6. Pack pairs of FP4 nibbles into bytes (float4_e2m1fn_x2)

Reference: dsv4/ops/quantize.py (quantize_activation_nvfp4)

Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_nvfp4_quantize_kernel.py
"""

import torch
|
||
import math
|
||
import cutlass
|
||
import cutlass.cute as cute
|
||
import cutlass.utils as utils
|
||
from cutlass import Float32, BFloat16, Float8E4M3FN, Int32, const_expr
|
||
from cutlass.utils import LayoutEnum
|
||
import cuda.bindings.driver as cuda
|
||
import cutlass.torch as ct
|
||
|
||
from dsv4.ops.quantize import quantize_activation_nvfp4, SF_VEC_SIZE
|
||
|
||
|
||
class Fp4QuantizeKernel:
    """Standalone BF16→FP4 quantization kernel (CuTeDSL) — not implemented yet.

    Intended to read BF16 from GMEM, quantize it to NVFP4 (packed FP4 data plus
    one FP8 E4M3 scale per ``block_size`` elements), and write both to GMEM.
    Only the host-side configuration exists so far; invoking the kernel raises
    ``NotImplementedError`` instead of silently leaving the outputs untouched.

    Args:
        M: number of rows in the input.
        N: number of columns; must be a multiple of ``block_size`` and even
            (two FP4 values are packed per output byte).
        block_size: elements per scale microblock (16 for NVFP4).
        global_scale: stored for the future kernel; not used yet.

    Raises:
        ValueError: if ``block_size`` is not positive, or ``N`` is not an even
            multiple of ``block_size``.
    """

    def __init__(self, M, N, block_size=16, global_scale=1.0):
        if block_size <= 0:
            raise ValueError(f"block_size must be positive, got {block_size}")
        # Reject shapes that would otherwise make n_blocks silently truncate
        # or leave a dangling unpaired FP4 nibble at the end of a row.
        if N % block_size != 0 or N % 2 != 0:
            raise ValueError(
                f"N ({N}) must be even and a multiple of block_size ({block_size})"
            )
        self.M = M
        self.N = N
        self.block_size = block_size  # 16 for NVFP4
        self.global_scale = global_scale
        self.n_blocks = N // block_size  # number of scale microblocks per row

    @cute.jit
    def __call__(self, x_bf16, x_fp4, x_sf, stream):
        """Quantize ``x_bf16`` into ``x_fp4`` / ``x_sf`` (not implemented yet).

        Args:
            x_bf16: (M, N) BF16 input.
            x_fp4: (M, N // 2) float4_e2m1fn_x2 output (packed FP4).
            x_sf: (M, N // block_size) float8_e4m3fn output (FP8 E4M3 scales).
            stream: stream to launch on (presumably a ``cuda.CUstream`` —
                confirm once the launch code exists).

        Raises:
            NotImplementedError: always; the device kernel has not been written.
        """
        # Planned thread mapping (design notes only, nothing below is built):
        #   - Grid (1, M, 1): one CTA per row.
        #   - 128 threads per CTA; each thread quantizes n_blocks / 128
        #     microblocks (N = 4096, block_size = 16 -> 256 blocks -> 2 each).
        #   - Each 16-element microblock produces 8 packed FP4 bytes and
        #     1 FP8 scale byte, i.e. N // 2 FP4 bytes and n_blocks scale bytes
        #     per row.
        # Until then, quantize_activation_nvfp4 is the ground truth.
        raise NotImplementedError(
            "Fp4QuantizeKernel device kernel is not implemented yet; "
            "use quantize_activation_nvfp4 as the reference."
        )

def test_nvfp4_quantize_correctness():
    """Check the reference NVFP4 quantizer's output format and round-trip accuracy.

    Quantizes random BF16 data with ``quantize_activation_nvfp4`` and asserts
    that the packed-FP4 and FP8 scale outputs have the expected dtypes and
    shapes. It then dequantizes with ``dequantize_nvfp4`` and asserts that the
    cosine similarity to the input is at least 0.95. Requires a CUDA device.
    """
    print("\n=== NVFP4 Quantization Correctness Test ===")

    torch.manual_seed(42)
    M, N = 128, 512

    x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda')
    global_scale = 1.0

    x_fp4, block_scale = quantize_activation_nvfp4(x, global_scale)

    print(f" Input: ({M}, {N}) BF16")
    print(f" FP4 output: {x_fp4.shape}, dtype={x_fp4.dtype}")
    print(f" Scale output: {block_scale.shape}, dtype={block_scale.dtype}")
    print(f" SF_VEC_SIZE: {SF_VEC_SIZE}")

    # Output format: two FP4 codes per byte, one FP8 scale per microblock.
    assert x_fp4.dtype == torch.float4_e2m1fn_x2, f"Expected float4_e2m1fn_x2, got {x_fp4.dtype}"
    assert block_scale.dtype == torch.float8_e4m3fn, f"Expected float8_e4m3fn, got {block_scale.dtype}"
    assert x_fp4.shape == (M, N // 2), f"Expected ({M}, {N//2}), got {x_fp4.shape}"
    assert block_scale.shape == (M, N // SF_VEC_SIZE), f"Expected ({M}, {N//SF_VEC_SIZE}), got {block_scale.shape}"

    print(" ✅ Shapes and dtypes correct")

    # Round trip: dequantize and compare against the original input.
    x_deq = dequantize_nvfp4(x_fp4, block_scale, global_scale, N)

    # Compare in FP32. FP4 is coarse, so ~0.97-0.99 is expected, not ~1.0.
    x_f32 = x.float()
    x_deq_f32 = x_deq.float()
    cos = torch.nn.functional.cosine_similarity(
        x_f32.flatten().unsqueeze(0),
        x_deq_f32.flatten().unsqueeze(0)
    ).item()

    print(f" Dequantized cosine similarity: {cos:.6f}")
    print(f" FP4 round-trip: {'PASS' if cos >= 0.95 else 'FAIL (expected ~0.97-0.99 for FP4)'}")

    # Mean relative error, skipping near-zero inputs whose relative error is
    # meaningless. Computed in FP32 so the error isn't itself rounded to BF16.
    mask = x_f32.abs() > 0.01
    rel_err = ((x_f32[mask] - x_deq_f32[mask]).abs() / x_f32[mask].abs()).mean().item()
    print(f" Mean relative error: {rel_err:.4f}")

    # Previously only printed, so the test could never fail on bad accuracy.
    assert cos >= 0.95, f"FP4 round-trip cosine similarity {cos:.6f} < 0.95"

def dequantize_nvfp4(x_fp4, block_scale, global_scale, N, block_size=None):
    """Dequantize packed NVFP4 data back to BF16 for verification.

    Intended as the inverse of ``quantize_activation_nvfp4``. Each byte of
    ``x_fp4`` holds two FP4 (E2M1) codes: the even-indexed element in the low
    nibble and the odd-indexed element in the high nibble. Each code is decoded
    to its E2M1 value, then multiplied by its microblock's scale and by
    ``global_scale``.

    Args:
        x_fp4: (M, N // 2) packed FP4 tensor (``torch.float4_e2m1fn_x2``, or
            any 1-byte dtype such as ``torch.uint8`` holding the same bits).
        block_scale: (M, N // block_size) per-microblock scales (for example
            ``torch.float8_e4m3fn``).
        global_scale: scalar multiplied into every element. NOTE(review):
            assumes the quantizer divided by this same factor — confirm
            against quantize_activation_nvfp4.
        N: number of unpacked elements per row.
        block_size: elements per scale microblock. When ``None`` (the
            default), it is inferred as ``N // block_scale.shape[-1]``, which
            is 16 for NVFP4 scales of shape (M, N // 16).

    Returns:
        (M, N) ``torch.bfloat16`` tensor on ``x_fp4``'s device.

    Raises:
        ValueError: if ``x_fp4`` / ``block_scale`` shapes are inconsistent
            with ``N`` (or with an explicit ``block_size``).
    """
    M = x_fp4.shape[0]
    if x_fp4.shape[-1] * 2 != N:
        raise ValueError(
            f"x_fp4 has {x_fp4.shape[-1]} packed bytes per row; "
            f"expected N / 2 for N={N}"
        )
    n_scales = block_scale.shape[-1]
    if block_size is None:
        if n_scales == 0 or N % n_scales != 0:
            raise ValueError(
                f"cannot infer block_size: N={N} is not a multiple of "
                f"{n_scales} scales per row"
            )
        block_size = N // n_scales
    elif n_scales * block_size != N:
        raise ValueError(
            f"block_scale has {n_scales} scales per row; expected "
            f"N // block_size = {N // block_size}"
        )

    # Unpack: low nibble = even element, high nibble = odd element.
    raw = x_fp4.view(torch.uint8)  # (M, N // 2) bytes
    low = raw & 0x0F
    high = (raw >> 4) & 0x0F
    codes = torch.stack([low, high], dim=-1).reshape(M, N)

    # E2M1 magnitudes for the 3 low bits: 2 exponent bits + 1 mantissa bit,
    # no inf/NaN. Code 1 is the subnormal 0.5 and code 7 is the max 6.0, which
    # matches the quantizer's half-step clamp at 12 (= 6.0). The grid is
    # non-uniform above 2.0, so decoding as `code` (0..7) would be wrong.
    e2m1_magnitudes = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
        dtype=torch.float32,
        device=x_fp4.device,
    )
    # .long(): a uint8 index tensor would be interpreted as a boolean mask.
    magnitudes = e2m1_magnitudes[(codes & 0x7).long()]
    # Bit 3 is the sign bit (codes 8-15 are negative).
    values = torch.where(codes >= 8, -magnitudes, magnitudes)  # (M, N)

    # Expand (M, N // block_size) scales to (M, N). Convert to FP32 first:
    # many ops are not implemented for float8 dtypes.
    scales = block_scale.float().repeat_interleave(block_size, dim=-1)
    x_deq_fp32 = values * scales * global_scale

    return x_deq_fp32.to(torch.bfloat16)

def test():
    """Script entry point: print a banner, then run the NVFP4 correctness check."""
    banner = "=== NVFP4-1.1: BF16→FP4 Quantization Kernel Development ==="
    print(banner)
    test_nvfp4_quantize_correctness()


if __name__ == "__main__":
    test()