Files
nvfp4-megamoe-kernel/tests/unit/test_nvfp4_quant_kernel.py

177 lines
7.1 KiB
Python

"""
NVFP4-1.1: BF16→FP4 quantization kernel (CuTeDSL, Blackwell SM100).
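Reads BF16 from GMEM and quantizes it to NVFP4; the goal is to write FP4 data + FP8 scales back to GMEM.
Current status: plain GMEM loads (TMA is planned), scales are stored as FP32, and the packed FP4 store is not written yet.
Grid: (num_rows, 1, 1) — 1 CTA per row.
Each CTA processes one row; its 128 threads split that row's 16-element blocks between them.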
Step 2 of the SwiGLU FP4 fusion.
Step 1: ✅ Python round-trip (cos 0.981)
Step 2: THIS — Standalone kernel
Step 3: Fuse into SwiGLU epilogue
Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_nvfp4_quant_kernel.py
"""
import torch
import math
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
from cutlass import Float32, BFloat16, Float8E4M3FN, Int32, const_expr
import cuda.bindings.driver as cuda
import cutlass.torch as ct
from dsv4.ops.quantize import quantize_activation_nvfp4, SF_VEC_SIZE
class Nvfp4QuantizeKernel:
    """Work-in-progress CuTeDSL kernel that quantizes a BF16 matrix to NVFP4.

    One CTA handles one row; its threads split that row's scale blocks
    (``block_size`` consecutive elements each) between them. For every block
    the kernel computes the absolute maximum, writes a per-block scale to
    ``x_sf`` and encodes each element as a 4-bit E2M1 code (bit 3 = sign,
    bits 0-2 = index into {0, 0.5, 1, 1.5, 2, 3, 4, 6}).

    NOTE(review): the packed FP4 bytes are computed but not yet stored to
    global memory, and the scale is written as FP32 instead of being rounded
    to FP8 E4M3 first.
    """

    # CTA size. Used both for the launch configuration and for splitting a
    # row's blocks between threads, so the two values cannot drift apart.
    NUM_THREADS = 128

    def __init__(self, M, N, block_size=16):
        """Record the problem shape.

        Args:
            M: number of rows; one CTA is launched per row.
            N: number of columns; must be a multiple of ``block_size``.
            block_size: number of elements sharing one scale factor.

        Raises:
            ValueError: if ``block_size`` is not positive or ``N`` is not a
                multiple of it (the trailing partial block would otherwise be
                silently skipped by ``N // block_size``).
        """
        if block_size <= 0 or N % block_size != 0:
            raise ValueError(
                f"N ({N}) must be a multiple of a positive block_size ({block_size})"
            )
        self.M = M
        self.N = N
        self.block_size = block_size

    @cute.jit
    def __call__(self, x_bf16, x_sf, stream):
        """Launch the quantization kernel on ``stream``.

        x_bf16: (M, N) BF16 input — intended to also receive the FP4 output
                in place (same memory); that store is not implemented yet.
        x_sf:   (M, N // block_size) scale factors (FP8 E4M3 intended; FP32
                values are currently written).
        """
        M = self.M
        N = self.N
        bs = self.block_size
        self._kernel(x_bf16, x_sf, Int32(M), Int32(N), Int32(bs)).launch(
            grid=(M, 1, 1), block=[self.NUM_THREADS, 1, 1], stream=stream
        )

    @cute.kernel
    def _kernel(self, x_bf16, x_sf, M, N, block_size):
        """Quantize row ``block_idx().x`` of ``x_bf16``; see the class docstring."""
        tidx, _, _ = cute.arch.thread_idx()
        bidx, _, _ = cute.arch.block_idx()
        row = bidx
        n_blocks = N // block_size  # number of scale blocks in this row
        threads = Int32(self.NUM_THREADS)
        # Ceil-divide so rows with fewer than NUM_THREADS blocks, or with a
        # block count that is not a multiple of it, are still fully covered.
        # The guard below skips the surplus indices of the last threads.
        blocks_per_thread = (n_blocks + threads - Int32(1)) // threads
        # Each thread processes blocks_per_thread consecutive blocks.
        for b in range(blocks_per_thread):
            block_idx = tidx * blocks_per_thread + b
            if block_idx < n_blocks:
                col_start = block_idx * block_size
                # Pass 1: absolute maximum of the block.
                amax = Float32(0.0)
                for i in range(block_size):
                    # Direct GMEM read (not TMA — simpler for first implementation)
                    val = x_bf16[row, col_start + i]
                    # No abs() primitive is used here: negate negative values.
                    abs_val = val if val > Float32(0.0) else Float32(0.0) - val
                    amax = amax if amax > abs_val else abs_val
                # Map amax onto the largest E2M1 magnitude (6.0). An all-zero
                # block gets scale 1 so the division below is well defined.
                # TODO: round the scale to FP8 E4M3 before using/storing it.
                scale = amax / Float32(6.0) if amax > Float32(0.0) else Float32(1.0)
                x_sf[row, block_idx] = scale
                # Pass 2: encode each element as an E2M1 code and pack pairs.
                packed = Int32(0)
                for i in range(block_size):
                    val = x_bf16[row, col_start + i]
                    scaled = val / scale
                    abs_scaled = scaled if scaled > Float32(0.0) else Float32(0.0) - scaled
                    # Round to the nearest E2M1 magnitude
                    # {0, 0.5, 1, 1.5, 2, 3, 4, 6} (indices 0..7). The
                    # thresholds are the midpoints; ties go to the even index
                    # (hence the alternating > / >=). Anything above 5
                    # saturates to index 7 (6.0).
                    fp4_idx = (
                        Int32(7) if abs_scaled > Float32(5.0) else
                        Int32(6) if abs_scaled >= Float32(3.5) else
                        Int32(5) if abs_scaled > Float32(2.5) else
                        Int32(4) if abs_scaled >= Float32(1.75) else
                        Int32(3) if abs_scaled > Float32(1.25) else
                        Int32(2) if abs_scaled >= Float32(0.75) else
                        Int32(1) if abs_scaled > Float32(0.25) else
                        Int32(0)
                    )
                    # Sign goes in bit 3 of the 4-bit code.
                    sign = Int32(1) if val < Float32(0.0) else Int32(0)
                    nibble = fp4_idx | (sign << Int32(3))
                    # Pack: even elements in lower nibble, odd in upper
                    if i % 2 == 0:
                        packed = nibble
                    else:
                        packed = packed | (nibble << Int32(4))
                    # TODO: store `packed` as one byte after each odd element;
                    # the FP4 output path (GMEM/TMA store) is not written yet.
def dequantize_nvfp4_simple(x_fp4, block_scale, global_scale, N, block_size=None):
    """Dequantize packed NVFP4 data back to BF16 (reference path for tests).

    Each byte of ``x_fp4`` holds two 4-bit E2M1 codes: the even element in
    the low nibble and the odd element in the high nibble. Bit 3 of a code is
    the sign; bits 0-2 index the E2M1 magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}.

    Args:
        x_fp4: packed FP4 tensor of shape (..., N // 2); viewed as uint8.
        block_scale: per-block scales of shape (..., N // block_size).
        global_scale: scalar (or broadcastable tensor) applied to every value.
        N: number of unpacked elements per row.
        block_size: elements per scale block; defaults to ``SF_VEC_SIZE``.

    Returns:
        BF16 tensor of shape (..., N) on the same device as ``x_fp4``.
    """
    if block_size is None:
        block_size = SF_VEC_SIZE
    raw = x_fp4.view(torch.uint8)
    lo = raw & 0x0F
    hi = (raw >> 4) & 0x0F
    # Interleave so code k of a row sits at column k (even = low nibble).
    codes = torch.stack([lo, hi], dim=-1).reshape(*raw.shape[:-1], N)
    # Bit 3 is the sign: map {0, 1} -> {+1, -1}.
    signs = 1.0 - 2.0 * ((codes >> 3) & 1).float()
    # E2M1 magnitude table (1 sign, 2 exponent, 1 mantissa bit), built on the
    # input's device so CPU tensors work as well as CUDA ones.
    e2m1 = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
                        dtype=torch.float32, device=raw.device)
    mags = e2m1[(codes & 0x07).long()]
    # Convert before repeating: float8 scale dtypes have limited op support.
    scales = block_scale.float().repeat_interleave(block_size, dim=-1)
    return (signs * mags * scales * global_scale).to(torch.bfloat16)
def test_nvfp4_python():
    """Round-trip BF16 -> NVFP4 -> BF16 through the Python reference path.

    Quantizes a seeded random (128, 512) BF16 matrix with
    ``quantize_activation_nvfp4``, dequantizes it with
    ``dequantize_nvfp4_simple`` and requires a cosine similarity of at least
    0.95 against the original.
    """
    print("\n=== NVFP4 Python Round-Trip ===")
    torch.manual_seed(42)
    rows, cols = 128, 512
    ref = torch.randn(rows, cols, dtype=torch.bfloat16, device='cuda')
    packed, scales = quantize_activation_nvfp4(ref, 1.0)
    recon = dequantize_nvfp4_simple(packed, scales, 1.0, cols)
    # Compare the whole matrices as single flattened vectors.
    ref_vec = ref.flatten().float().unsqueeze(0)
    recon_vec = recon.flatten().float().unsqueeze(0)
    cos = torch.nn.functional.cosine_similarity(ref_vec, recon_vec).item()
    verdict = 'PASS' if cos >= 0.95 else 'FAIL'
    print(f" Round-trip cos: {cos:.6f} ({verdict})")
    assert cos >= 0.95, f"Round-trip cosine too low: {cos}"
def test_nvfp4_kernel_launch():
    """Placeholder for the CuTeDSL kernel test.

    ``Nvfp4QuantizeKernel`` is not launched here yet. Until it is, this
    checks the Python reference round-trip (``quantize_activation_nvfp4`` ->
    ``dequantize_nvfp4_simple``) on a small (4, 64) BF16 matrix and requires
    the same 0.95 cosine floor as ``test_nvfp4_python``.
    """
    print("\n=== NVFP4 Kernel Launch Test ===")
    print(" (Kernel implementation in progress — CuTeDSL quantization needs TMA + FP4 packing)")
    print(" Current status: quantization logic designed, need to add TMA store for FP4 output")
    # For now, verify that the Python quantization matches our dequantize
    torch.manual_seed(42)
    M, N = 4, 64
    x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda')
    x_fp4, sf = quantize_activation_nvfp4(x, 1.0)
    x_deq = dequantize_nvfp4_simple(x_fp4, sf, 1.0, N)
    cos = torch.nn.functional.cosine_similarity(
        x.flatten().float().unsqueeze(0), x_deq.flatten().float().unsqueeze(0)
    ).item()
    print(f" Small test cos: {cos:.6f}")
    # Assert rather than only print, so a regression actually fails the test.
    assert cos >= 0.95, f"Small round-trip cosine too low: {cos}"
def test():
    """Run every check in this file in order (used when run as a script)."""
    print("=== NVFP4-1.1: BF16→FP4 Quantization ===")
    for check in (test_nvfp4_python, test_nvfp4_kernel_launch):
        check()
if __name__ == '__main__':
    # Script entry point: run all checks directly, without pytest.
    test()