# Listing metadata: 190 lines, 7.5 KiB, Python
"""
NVFP4 GPU Quantize Kernel Test.

Tests:
  1. quantize_nvfp4_gpu: BF16 → FP4 (no deinterleave, GPU-only, no CPU sync)
  2. deinterleave_quantize_nvfp4_cuda: interleaved BF16 → FP4 (deinterleave + quantize)
  3. Integration: both kernels match the Python quantize_activation_nvfp4 reference

Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_nvfp4_gpu_quantize.py
"""
|
|
import torch
|
|
import math
|
|
import sys, os
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
|
|
|
from dsv4.ops.quantize import (
|
|
quantize_activation_nvfp4,
|
|
quantize_nvfp4_gpu,
|
|
deinterleave_quantize_nvfp4_cuda,
|
|
SF_VEC_SIZE,
|
|
)
|
|
from dsv4.ops.layouts import interleave_l1_weights
|
|
|
|
|
|
def test_quantize_nvfp4_gpu_basic():
    """Check the GPU-only quantize kernel against the Python reference.

    A seeded random BF16 matrix is quantized by both
    ``quantize_activation_nvfp4`` (reference) and ``quantize_nvfp4_gpu``.
    The two results are compared on shape, dtype, raw bytes, and on the
    cosine similarity of their dequantized round trips.

    Raises:
        AssertionError: if shapes or dtypes differ, if fewer than 99% of the
            packed FP4 bytes or scale-factor bytes agree, or if a cosine
            similarity falls below its threshold.
    """
    print("\n=== Test 1: quantize_nvfp4_gpu basic correctness ===")
    torch.manual_seed(42)
    M, N = 128, 512
    x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda')
    global_scale = 1.0

    def _cosine(a, b):
        """Return the cosine similarity of two tensors flattened to FP32."""
        return torch.nn.functional.cosine_similarity(
            a.flatten().float().unsqueeze(0), b.flatten().float().unsqueeze(0)
        ).item()

    # Python reference
    ref_fp4, ref_sf = quantize_activation_nvfp4(x, global_scale)

    # GPU kernel
    gpu_fp4, gpu_sf = quantize_nvfp4_gpu(x, global_scale)

    # Compare shapes and dtypes
    assert gpu_fp4.shape == ref_fp4.shape, f"FP4 shape mismatch: {gpu_fp4.shape} vs {ref_fp4.shape}"
    assert gpu_sf.shape == ref_sf.shape, f"SF shape mismatch: {gpu_sf.shape} vs {ref_sf.shape}"
    assert gpu_fp4.dtype == torch.float4_e2m1fn_x2, f"FP4 dtype: {gpu_fp4.dtype}"
    assert gpu_sf.dtype == torch.float8_e4m3fn, f"SF dtype: {gpu_sf.dtype}"

    # Fraction of identical bytes in the packed FP4 payload
    fp4_match = (ref_fp4.view(torch.uint8) == gpu_fp4.view(torch.uint8)).float().mean().item()
    print(f" FP4 byte match: {fp4_match*100:.1f}%")

    # Fraction of identical bytes in the block scale factors
    sf_match = (ref_sf.view(torch.uint8) == gpu_sf.view(torch.uint8)).float().mean().item()
    print(f" SF byte match: {sf_match*100:.1f}%")

    # Round-trip cosine similarity
    ref_deq = _dequantize_nvfp4(ref_fp4, ref_sf, global_scale, N)
    gpu_deq = _dequantize_nvfp4(gpu_fp4, gpu_sf, global_scale, N)

    cos_ref = _cosine(x, ref_deq)
    cos_gpu = _cosine(x, gpu_deq)
    cos_cross = _cosine(ref_deq, gpu_deq)

    print(f" Python round-trip cos: {cos_ref:.6f}")
    print(f" GPU round-trip cos: {cos_gpu:.6f}")
    print(f" Python vs GPU cos: {cos_cross:.6f}")

    # The byte-match ratios used to be printed only, so a byte-level
    # regression could not fail this test. Enforce the same 99% bar that the
    # other byte-comparison tests in this file use. All diagnostics are
    # printed before the first assertion so a failure still shows them.
    assert fp4_match >= 0.99, f"FP4 byte match too low: {fp4_match}"
    assert sf_match >= 0.99, f"SF byte match too low: {sf_match}"
    assert cos_gpu >= 0.95, f"GPU round-trip cosine too low: {cos_gpu}"
    assert cos_cross >= 0.99, f"Python vs GPU cosine too low: {cos_cross}"
    print(" ✅ PASS")
|
|
|
|
|
|
def test_quantize_nvfp4_gpu_larger():
    """Compare GPU and reference quantization on a wider (64 x 4096) input.

    The global scale is taken from the third value returned by
    ``quantize_to_nvfp4`` for the same input rather than being fixed at 1.0.

    Raises:
        AssertionError: if fewer than 99% of the FP4 bytes or of the
            scale-factor bytes agree between the two implementations.
    """
    print("\n=== Test 2: quantize_nvfp4_gpu larger shape ===")
    torch.manual_seed(42)
    rows, cols = 64, 4096
    x = torch.randn(rows, cols, dtype=torch.bfloat16, device='cuda')

    # Derive the global scale from the data instead of hard-coding it.
    from dsv4.ops.quantize import quantize_to_nvfp4
    _, _, global_scale = quantize_to_nvfp4(x)

    def _byte_agreement(lhs, rhs):
        """Fraction of positions where the raw bytes of both tensors match."""
        same = lhs.view(torch.uint8) == rhs.view(torch.uint8)
        return same.float().mean().item()

    ref_fp4, ref_sf = quantize_activation_nvfp4(x, global_scale)
    gpu_fp4, gpu_sf = quantize_nvfp4_gpu(x, global_scale)

    # Byte-level agreement of payload and scale factors
    fp4_match = _byte_agreement(ref_fp4, gpu_fp4)
    sf_match = _byte_agreement(ref_sf, gpu_sf)
    print(f" FP4 byte match: {fp4_match*100:.1f}%")
    print(f" SF byte match: {sf_match*100:.1f}%")

    assert fp4_match >= 0.99, f"FP4 byte match too low: {fp4_match}"
    assert sf_match >= 0.99, f"SF byte match too low: {sf_match}"
    print(" ✅ PASS")
|
|
|
|
|
|
def test_quantize_nvfp4_gpu_no_cpu_sync():
    """Verify quantize_nvfp4_gpu has no CPU-GPU sync points.

    The call under test runs with PyTorch's CUDA sync debug mode set to
    ``"error"``, so a synchronizing operation that PyTorch tracks raises
    instead of passing silently. The previous debug mode is restored
    afterwards, even when the call fails.

    NOTE(review): sync debug mode only covers synchronizations issued through
    PyTorch; a raw ``cudaDeviceSynchronize`` inside a custom extension would
    presumably not be reported — confirm against the kernel source.

    Raises:
        RuntimeError: if the quantize call triggers a synchronization that
            PyTorch's sync debug mode detects.
    """
    print("\n=== Test 3: No CPU-GPU sync verification ===")
    torch.manual_seed(42)
    M, N = 32, 512
    x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda')
    global_scale = 1.0

    # Warm up outside the guarded region so one-time work on first use is not
    # mistaken for a sync on the steady-state path.
    quantize_nvfp4_gpu(x, global_scale)
    torch.cuda.synchronize()

    # This call must NOT trigger a CUDA synchronization; if it did, CUDA graph
    # capture of the same path would fail.
    previous_mode = torch.cuda.get_sync_debug_mode()
    torch.cuda.set_sync_debug_mode("error")
    try:
        gpu_fp4, gpu_sf = quantize_nvfp4_gpu(x, global_scale)
    finally:
        # Always restore, otherwise the deliberate .item() below (and every
        # later test) would run under the strict mode.
        torch.cuda.set_sync_debug_mode(previous_mode)

    # Force a sync to check results
    result = gpu_fp4.view(torch.uint8).sum().item()
    print(f" FP4 sum (verification): {result}")
    print(" ✅ PASS (no crash = no CPU sync in kernel path)")
|
|
|
|
|
|
def test_deinterleave_quantize_correctness():
    """Check the fused deinterleave + quantize CUDA kernel against Python.

    Builds ``silu(gate)`` and ``silu(gate) * up`` from seeded random inputs,
    interleaves them with ``interleave_l1_weights``, and feeds the result to
    ``deinterleave_quantize_nvfp4_cuda``. The reference is the plain
    ``quantize_activation_nvfp4`` of the ``silu(gate) * up`` half.

    Raises:
        AssertionError: if fewer than 99% of the FP4 bytes or of the
            scale-factor bytes agree with the reference.
    """
    print("\n=== Test 4: deinterleave_quantize_nvfp4_cuda correctness ===")
    torch.manual_seed(42)
    M = 64
    intermediate = 512
    N = 2 * intermediate  # width of the fused tensor: silu(gate) half + swiglu half

    # Inputs mimicking the fused kernel output.
    # NOTE(review): the original notes describe the layout as alternating
    # groups of silu(gate) and silu(gate)*up, presumably 8 elements each —
    # confirm against interleave_l1_weights.
    gate = torch.randn(M, intermediate, dtype=torch.bfloat16, device='cuda')
    up = torch.randn(M, intermediate, dtype=torch.bfloat16, device='cuda')
    gate_silu = torch.nn.functional.silu(gate)
    swiglu_result = gate_silu * up

    # Concatenate both halves, add a leading batch axis for the interleave
    # helper, then drop that axis again to get the (M, N) interleaved tensor.
    stacked = torch.cat([gate_silu, swiglu_result], dim=-1)
    fused = interleave_l1_weights(stacked.unsqueeze(0))[0]

    global_scale = 1.0

    # Python reference: quantize the swiglu half directly.
    ref_fp4, ref_sf = quantize_activation_nvfp4(swiglu_result, global_scale)

    # CUDA kernel: per the original notes, picks the swiglu (odd) groups out
    # of the interleaved tensor and quantizes them in one pass.
    gpu_fp4, gpu_sf = deinterleave_quantize_nvfp4_cuda(fused, intermediate, global_scale)

    def _byte_agreement(lhs, rhs):
        """Fraction of positions where the raw bytes of both tensors match."""
        same = lhs.view(torch.uint8) == rhs.view(torch.uint8)
        return same.float().mean().item()

    fp4_match = _byte_agreement(ref_fp4, gpu_fp4)
    sf_match = _byte_agreement(ref_sf, gpu_sf)
    print(f" FP4 byte match: {fp4_match*100:.1f}%")
    print(f" SF byte match: {sf_match*100:.1f}%")

    assert fp4_match >= 0.99, f"FP4 byte match too low: {fp4_match}"
    assert sf_match >= 0.99, f"SF byte match too low: {sf_match}"
    print(" ✅ PASS")
|
|
|
|
|
|
def test():
    """Run every NVFP4 quantize test in order, stopping at the first failure."""
    print("=== NVFP4 GPU Quantize Kernel Tests ===")
    suite = (
        test_quantize_nvfp4_gpu_basic,
        test_quantize_nvfp4_gpu_larger,
        test_quantize_nvfp4_gpu_no_cpu_sync,
        test_deinterleave_quantize_correctness,
    )
    for case in suite:
        case()
    # Only reached when no test raised.
    print("\n=== ALL TESTS PASSED ===")
|
|
|
|
|
|
def _dequantize_nvfp4(x_fp4, block_scale, global_scale, N):
    """Dequantize NVFP4 back to BF16 for verification.

    Args:
        x_fp4: Packed FP4 tensor, reinterpreted here as ``uint8`` with two
            4-bit codes per byte (low nibble first). Presumably shaped
            ``(M, N // 2)`` — the reshape to ``(M, N)`` below requires it.
        block_scale: Per-block scale factors; each one is repeated
            ``SF_VEC_SIZE`` times along the last axis, so that axis must
            expand to ``N``.
        global_scale: Scalar multiplier applied on top of the block scales.
        N: Number of dequantized elements per row.

    Returns:
        A ``torch.bfloat16`` tensor of shape ``(M, N)``.
    """
    M = x_fp4.shape[0]
    block_size = SF_VEC_SIZE

    # Split each byte into its two 4-bit codes: low nibble, then high nibble.
    raw = x_fp4.view(torch.uint8)
    low = raw & 0x0F
    high = (raw >> 4) & 0x0F
    codes = torch.stack([low, high], dim=-1).reshape(M, N)

    # Bit 3 is the sign; the low three bits select the magnitude.
    signs = (codes >= 8).float() * -2 + 1
    mag = codes % 8

    # E2M1 magnitudes are {0, 0.5, 1, 1.5, 2, 3, 4, 6}, stored here in units
    # of one half. The previous linear table (0, 2, ..., 14) decoded codes as
    # 0..7, which is not the E2M1 value set. The table follows the input's
    # device rather than assuming 'cuda'.
    idx_to_half_step = torch.tensor([0, 1, 2, 3, 4, 6, 8, 12],
                                    dtype=torch.float32, device=raw.device)
    half_steps = idx_to_half_step[mag.long()]
    x_deq_fp32 = signs * half_steps / 2.0

    # Broadcast every block scale across the elements of its block.
    block_scale_exp = block_scale.float().repeat_interleave(block_size, dim=-1)
    x_deq_fp32 = x_deq_fp32 * block_scale_exp * global_scale
    return x_deq_fp32.to(torch.bfloat16)
|
|
|
|
|
|
# Script entry point: run the whole suite when executed directly.
if __name__ == "__main__":
    test()
|