177 lines
7.1 KiB
Python
177 lines
7.1 KiB
Python
"""
|
|
NVFP4-1.1: BF16→FP4 quantization kernel (CuTeDSL, Blackwell SM100).
|
|
|
|
Reads BF16 from GMEM, quantizes to NVFP4, writes FP4 + FP8 scales to GMEM.
|
|
Uses TMA for efficient GMEM access.
|
|
|
|
Grid: (num_rows, 1, 1) — 1 CTA per row.
|
|
Each CTA processes one row, with 128 threads each handling multiple 16-element blocks.
|
|
|
|
Step 2 of the SwiGLU FP4 fusion.
|
|
Step 1: ✅ Python round-trip (cos 0.981)
|
|
Step 2: THIS — Standalone kernel
|
|
Step 3: Fuse into SwiGLU epilogue
|
|
|
|
Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_nvfp4_quant_kernel.py
|
|
"""
|
|
import torch
|
|
import math
|
|
import cutlass
|
|
import cutlass.cute as cute
|
|
import cutlass.utils as utils
|
|
from cutlass import Float32, BFloat16, Float8E4M3FN, Int32, const_expr
|
|
import cuda.bindings.driver as cuda
|
|
import cutlass.torch as ct
|
|
|
|
from dsv4.ops.quantize import quantize_activation_nvfp4, SF_VEC_SIZE
|
|
|
|
|
|
class Nvfp4QuantizeKernel:
    """Draft CuTeDSL kernel that quantizes a BF16 (M, N) matrix toward NVFP4.

    One CTA is launched per row with ``num_threads`` threads. Each thread
    handles a contiguous run of ``block_size``-element blocks. The current
    draft computes one scale per block and stores it into ``x_sf``. It also
    computes the packed 4-bit codes, but does not yet store them anywhere.
    """

    # Threads per CTA; shared by the launch configuration and the kernel's
    # work split so the two cannot drift apart.
    num_threads = 128

    def __init__(self, M, N, block_size=16):
        """Record the problem shape.

        Args:
            M: number of rows (one CTA is launched per row).
            N: number of columns; must be a positive multiple of ``block_size``.
            block_size: elements per scale-factor block (default 16). Must be
                even, because two 4-bit codes are packed into each byte.

        Raises:
            ValueError: if ``block_size`` is not a positive even integer, or
                ``N`` is not a positive multiple of it (the kernel would
                otherwise silently ignore the row's tail).
        """
        if block_size <= 0 or block_size % 2 != 0:
            raise ValueError(
                f"block_size must be a positive even integer, got {block_size}"
            )
        if N <= 0 or N % block_size != 0:
            raise ValueError(
                f"N ({N}) must be a positive multiple of block_size ({block_size})"
            )
        self.M = M
        self.N = N
        self.block_size = block_size

    @cute.jit
    def __call__(self, x_bf16, x_sf, stream):
        """Launch one CTA per row.

        Args:
            x_bf16: (M, N) BF16 input. The original design intends this buffer
                to also receive the FP4 output in place, but no FP4 store exists
                yet. NOTE(review): writing FP4 bytes into a row that other
                threads of the same CTA are still reading is a likely race;
                confirm before implementing the store in place.
            x_sf: (M, N // block_size) scale-factor tensor. The design calls
                for FP8 E4M3; the kernel currently stores Float32 values into it.
            stream: CUDA stream to launch on.
        """
        M, N, bs = self.M, self.N, self.block_size
        self._kernel(x_bf16, x_sf, Int32(M), Int32(N), Int32(bs)).launch(
            grid=(M, 1, 1), block=[self.num_threads, 1, 1], stream=stream
        )

    @cute.kernel
    def _kernel(self, x_bf16, x_sf, M, N, block_size):
        """Per-row quantization body (one CTA per row, indexed by block_idx.x).

        Each thread owns ``ceil(n_blocks / num_threads)`` consecutive blocks of
        ``block_size`` elements. Indices past the end of the row are skipped.
        """
        tidx, _, _ = cute.arch.thread_idx()
        bidx, _, _ = cute.arch.block_idx()
        row = bidx

        n_blocks = N // block_size  # number of scale-factor blocks per row
        threads = Int32(self.num_threads)
        # Ceil-divide so rows whose block count is not a multiple of the thread
        # count (including rows with fewer blocks than threads) are fully
        # covered. The bounds check below discards the overhang.
        blocks_per_thread = (n_blocks + threads - Int32(1)) // threads

        for b in range(blocks_per_thread):
            block_idx = tidx * blocks_per_thread + b
            if block_idx < n_blocks:
                col_start = block_idx * block_size

                # Pass 1: amax = max(|x|) over the block (direct GMEM reads,
                # not TMA).
                amax = Float32(0.0)
                for i in range(block_size):
                    val = x_bf16[row, col_start + i]
                    abs_val = val if val > Float32(0.0) else Float32(0.0) - val
                    amax = amax if amax > abs_val else abs_val

                # The scale maps the block's amax onto 6.0, the largest FP4
                # magnitude; all-zero blocks use 1.0 to avoid dividing by zero.
                # Stored as Float32 for now (x_sf is meant to hold FP8 E4M3).
                scale = amax / Float32(6.0) if amax > Float32(0.0) else Float32(1.0)
                x_sf[row, block_idx] = scale

                # Pass 2: quantize each element to a 4-bit code and pack two
                # codes per byte.
                packed = Int32(0)
                for i in range(block_size):
                    val = x_bf16[row, col_start + i]
                    scaled = val / scale
                    abs_scaled = scaled if scaled > Float32(0.0) else Float32(0.0) - scaled

                    # Round |scaled| * 2 to the nearest integer via
                    # floor(x + 0.5), then clamp to [0, 12].
                    half_step = abs_scaled * Float32(2.0) + Float32(0.5)
                    half_step = half_step if half_step > Float32(0.0) else Float32(0.0)
                    half_step = half_step if half_step < Float32(12.0) else Float32(12.0)
                    hs_int = Int32(half_step)

                    # hs_int lies in [0, 12] and may be odd; halving it yields a
                    # magnitude index in [0, 6].
                    # NOTE(review): this is a uniform magnitude grid, not the
                    # E2M1 code table {0, 0.5, 1, 1.5, 2, 3, 4, 6}; confirm
                    # against the Python reference quantizer before wiring up
                    # the FP4 store.
                    fp4_idx = hs_int // Int32(2)
                    fp4_idx = fp4_idx if fp4_idx < Int32(7) else Int32(6)

                    # Bit 3 carries the sign.
                    sign = Int32(1) if val < Float32(0.0) else Int32(0)
                    nibble = fp4_idx | (sign << Int32(3))

                    # Even elements go in the low nibble, odd ones in the high
                    # nibble.
                    if i % 2 == 0:
                        packed = nibble
                    else:
                        packed = packed | (nibble << Int32(4))
                    # TODO: store `packed` after each odd `i`. This needs a
                    # GMEM/TMA write path for the FP4 output; until then the
                    # packed byte is discarded.
|
|
|
|
|
|
def dequantize_nvfp4_simple(x_fp4, block_scale, global_scale, N, block_size=None):
    """Dequantize packed NVFP4 data back to BF16 (reference helper for tests).

    Args:
        x_fp4: packed FP4 tensor whose last dim holds two 4-bit codes per byte
            (low nibble = even element, high nibble = odd element). It is
            reinterpreted as uint8.
        block_scale: per-block scale factors whose last dim is
            ``N // block_size``; any dtype convertible to float32 (e.g. FP8).
        global_scale: scalar multiplied into every dequantized value.
        N: number of unpacked elements along the last dim.
        block_size: elements per scale block; defaults to ``SF_VEC_SIZE``.

    Returns:
        BF16 tensor of shape ``(*x_fp4.shape[:-1], N)`` on ``x_fp4``'s device.
    """
    if block_size is None:
        block_size = SF_VEC_SIZE

    raw = x_fp4.view(torch.uint8)
    # Interleave low/high nibbles so the codes come out in element order.
    codes = torch.stack([raw & 0x0F, (raw >> 4) & 0x0F], dim=-1).reshape(
        *raw.shape[:-1], N
    )

    # FP4 E2M1 code layout: bit 3 is the sign, bits 0-2 index the magnitude.
    signs = 1.0 - 2.0 * (codes >= 8).float()
    e2m1_values = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
        dtype=torch.float32,
        device=raw.device,
    )
    x_deq = signs * e2m1_values[(codes & 0x7).long()]

    # Convert scales to float32 before expanding them, so the expansion does
    # not depend on which ops are implemented for float8 dtypes.
    scale_per_elem = block_scale.float().repeat_interleave(block_size, dim=-1)
    x_deq = x_deq * scale_per_elem * global_scale
    return x_deq.to(torch.bfloat16)
|
|
|
|
|
|
def test_nvfp4_python():
    """Check that the Python NVFP4 quantizer survives a dequantize round trip.

    Quantizes a seeded 128x512 BF16 matrix, dequantizes it, and requires the
    cosine similarity between input and reconstruction to be at least 0.95.
    """
    print("\n=== NVFP4 Python Round-Trip ===")
    torch.manual_seed(42)
    rows, cols = 128, 512
    original = torch.randn(rows, cols, dtype=torch.bfloat16, device='cuda')

    packed, scales = quantize_activation_nvfp4(original, 1.0)
    restored = dequantize_nvfp4_simple(packed, scales, 1.0, cols)

    # Compare both tensors as single flattened float32 vectors.
    flat_ref = original.flatten().float().unsqueeze(0)
    flat_out = restored.flatten().float().unsqueeze(0)
    cos = torch.nn.functional.cosine_similarity(flat_ref, flat_out).item()

    verdict = 'PASS' if cos >= 0.95 else 'FAIL'
    print(f" Round-trip cos: {cos:.6f} ({verdict})")
    assert cos >= 0.95, f"Round-trip cosine too low: {cos}"
|
|
|
|
|
|
def test_nvfp4_kernel_launch():
    """Placeholder for the CuTeDSL kernel test; currently a small Python round trip.

    The kernel is not launched yet because its FP4 store path is unfinished.
    Until it is, this runs the Python quantizer and dequantizer on a seeded 4x64
    BF16 input and asserts a 0.95 cosine-similarity floor, so a broken round
    trip fails the test instead of only printing a number.
    """
    print("\n=== NVFP4 Kernel Launch Test ===")
    print(" (Kernel implementation in progress — CuTeDSL quantization needs TMA + FP4 packing)")
    print(" Current status: quantization logic designed, need to add TMA store for FP4 output")

    # For now, verify that the Python quantization matches our dequantize.
    torch.manual_seed(42)
    M, N = 4, 64
    x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda')
    x_fp4, sf = quantize_activation_nvfp4(x, 1.0)
    x_deq = dequantize_nvfp4_simple(x_fp4, sf, 1.0, N)
    cos = torch.nn.functional.cosine_similarity(
        x.flatten().float().unsqueeze(0), x_deq.flatten().float().unsqueeze(0)
    ).item()
    print(f" Small test cos: {cos:.6f}")
    assert cos >= 0.95, f"Small round-trip cosine too low: {cos}"
|
|
|
|
|
|
def test():
    """Run every NVFP4 check in this module, in order, under one banner."""
    print("=== NVFP4-1.1: BF16→FP4 Quantization ===")
    for check in (test_nvfp4_python, test_nvfp4_kernel_launch):
        check()
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly as a script, outside pytest.
    test()
|