Files
nvfp4-megamoe-kernel/tests/unit/test_nvfp4_gpu_quantize.py

190 lines
7.5 KiB
Python

"""
NVFP4 GPU Quantize Kernel Test.
Tests:
1. quantize_nvfp4_gpu: BF16 → FP4 (no deinterleave, GPU-only, no CPU sync)
2. deinterleave_quantize_nvfp4_cuda: interleaved BF16 → FP4 (deinterleave + quantize)
3. Integration: Both kernels match the Python quantize_activation_nvfp4 reference
Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_nvfp4_gpu_quantize.py
"""
import torch
import math
import sys, os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from dsv4.ops.quantize import (
quantize_activation_nvfp4,
quantize_nvfp4_gpu,
deinterleave_quantize_nvfp4_cuda,
SF_VEC_SIZE,
)
from dsv4.ops.layouts import interleave_l1_weights
def test_quantize_nvfp4_gpu_basic():
    """Check quantize_nvfp4_gpu against quantize_activation_nvfp4 on a 128x512 input.

    The outputs must agree with the reference in shape, and carry the
    float4_e2m1fn_x2 / float8_e4m3fn dtypes. Byte-level agreement of the
    packed data and the scale factors is printed for information only; the
    pass/fail criteria are the round-trip cosine similarities.
    """
    print("\n=== Test 1: quantize_nvfp4_gpu basic correctness ===")
    torch.manual_seed(42)
    rows, cols = 128, 512
    x = torch.randn(rows, cols, dtype=torch.bfloat16, device='cuda')
    global_scale = 1.0

    def byte_agreement(lhs, rhs):
        # Fraction of positions whose raw bytes are identical.
        same = lhs.view(torch.uint8) == rhs.view(torch.uint8)
        return same.float().mean().item()

    def cosine(lhs, rhs):
        # Cosine similarity of both tensors flattened into float32 row vectors.
        a = lhs.flatten().float().unsqueeze(0)
        b = rhs.flatten().float().unsqueeze(0)
        return torch.nn.functional.cosine_similarity(a, b).item()

    # Reference path first, then the kernel under test.
    ref_fp4, ref_sf = quantize_activation_nvfp4(x, global_scale)
    gpu_fp4, gpu_sf = quantize_nvfp4_gpu(x, global_scale)

    # Structural checks: shapes follow the reference, dtypes are fixed.
    assert gpu_fp4.shape == ref_fp4.shape, f"FP4 shape mismatch: {gpu_fp4.shape} vs {ref_fp4.shape}"
    assert gpu_sf.shape == ref_sf.shape, f"SF shape mismatch: {gpu_sf.shape} vs {ref_sf.shape}"
    assert gpu_fp4.dtype == torch.float4_e2m1fn_x2, f"FP4 dtype: {gpu_fp4.dtype}"
    assert gpu_sf.dtype == torch.float8_e4m3fn, f"SF dtype: {gpu_sf.dtype}"

    # Informational only: these ratios are reported, not asserted.
    fp4_match = byte_agreement(ref_fp4, gpu_fp4)
    print(f" FP4 byte match: {fp4_match*100:.1f}%")
    sf_match = byte_agreement(ref_sf, gpu_sf)
    print(f" SF byte match: {sf_match*100:.1f}%")

    # Dequantize both results and compare against the input and each other.
    ref_deq = _dequantize_nvfp4(ref_fp4, ref_sf, global_scale, cols)
    gpu_deq = _dequantize_nvfp4(gpu_fp4, gpu_sf, global_scale, cols)
    cos_ref = cosine(x, ref_deq)
    cos_gpu = cosine(x, gpu_deq)
    cos_cross = cosine(ref_deq, gpu_deq)
    print(f" Python round-trip cos: {cos_ref:.6f}")
    print(f" GPU round-trip cos: {cos_gpu:.6f}")
    print(f" Python vs GPU cos: {cos_cross:.6f}")

    assert cos_gpu >= 0.95, f"GPU round-trip cosine too low: {cos_gpu}"
    assert cos_cross >= 0.99, f"Python vs GPU cosine too low: {cos_cross}"
    print(" ✅ PASS")
def test_quantize_nvfp4_gpu_larger():
    """Check quantize_nvfp4_gpu against the reference on a 64x4096 input.

    The global scale is taken from the third value returned by
    quantize_to_nvfp4(x) rather than being fixed at 1.0. At least 99% of the
    packed FP4 bytes and of the scale-factor bytes must match the reference.
    """
    # Only this test needs quantize_to_nvfp4, so it is imported locally.
    from dsv4.ops.quantize import quantize_to_nvfp4

    print("\n=== Test 2: quantize_nvfp4_gpu larger shape ===")
    torch.manual_seed(42)
    rows, cols = 64, 4096
    x = torch.randn(rows, cols, dtype=torch.bfloat16, device='cuda')

    def byte_agreement(lhs, rhs):
        # Fraction of positions whose raw bytes are identical.
        same = lhs.view(torch.uint8) == rhs.view(torch.uint8)
        return same.float().mean().item()

    # Presumably derived from the data's amax -- confirm in quantize_to_nvfp4.
    global_scale = quantize_to_nvfp4(x)[2]

    ref_fp4, ref_sf = quantize_activation_nvfp4(x, global_scale)
    gpu_fp4, gpu_sf = quantize_nvfp4_gpu(x, global_scale)

    fp4_match = byte_agreement(ref_fp4, gpu_fp4)
    sf_match = byte_agreement(ref_sf, gpu_sf)
    print(f" FP4 byte match: {fp4_match*100:.1f}%")
    print(f" SF byte match: {sf_match*100:.1f}%")

    assert fp4_match >= 0.99, f"FP4 byte match too low: {fp4_match}"
    assert sf_match >= 0.99, f"SF byte match too low: {sf_match}"
    print(" ✅ PASS")
def test_quantize_nvfp4_gpu_no_cpu_sync():
    """Verify quantize_nvfp4_gpu has no CPU-GPU sync points.

    A synchronizing call does not crash on its own, so merely running the
    kernel proves nothing. The call is therefore executed with PyTorch's sync
    debug mode set to "error", which makes PyTorch raise when it performs a
    synchronizing operation. The previous debug mode is always restored.

    Raises:
        RuntimeError: if PyTorch detects a synchronizing operation during the
            guarded quantize_nvfp4_gpu call.
    """
    print("\n=== Test 3: No CPU-GPU sync verification ===")
    torch.manual_seed(42)
    M, N = 32, 512
    x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda')
    global_scale = 1.0

    # Warm-up outside the guarded region so that any one-time setup on first
    # use is not reported as a sync in the steady-state path.
    # NOTE(review): whether quantize_nvfp4_gpu has such setup is not visible here.
    quantize_nvfp4_gpu(x, global_scale)
    torch.cuda.synchronize()

    previous_mode = torch.cuda.get_sync_debug_mode()
    torch.cuda.set_sync_debug_mode("error")
    try:
        # Must not trigger a synchronization; one would also break
        # CUDA graph capture.
        gpu_fp4, _ = quantize_nvfp4_gpu(x, global_scale)
    finally:
        torch.cuda.set_sync_debug_mode(previous_mode)

    # Deliberate sync, outside the guarded region, to surface any
    # asynchronous kernel error and to print a value derived from the output.
    result = gpu_fp4.view(torch.uint8).sum().item()
    print(f" FP4 sum (verification): {result}")
    print(" ✅ PASS (no sync detected under torch.cuda sync debug mode)")
def test_deinterleave_quantize_correctness():
    """Check deinterleave_quantize_nvfp4_cuda against the Python reference.

    A fused tensor is built from silu(gate) and silu(gate) * up, concatenated
    along the last dimension and passed through interleave_l1_weights. The
    CUDA kernel receives that fused tensor; the reference quantizes the
    silu(gate) * up half directly. At least 99% of the FP4 bytes and of the
    scale-factor bytes must agree.
    """
    print("\n=== Test 4: deinterleave_quantize_nvfp4_cuda correctness ===")
    torch.manual_seed(42)
    rows = 64
    intermediate = 512  # width of each half; the fused tensor is 2 * intermediate wide

    def byte_agreement(lhs, rhs):
        # Fraction of positions whose raw bytes are identical.
        same = lhs.view(torch.uint8) == rhs.view(torch.uint8)
        return same.float().mean().item()

    # Draw gate before up so the random stream is consumed in a fixed order.
    gate = torch.randn(rows, intermediate, dtype=torch.bfloat16, device='cuda')
    up = torch.randn(rows, intermediate, dtype=torch.bfloat16, device='cuda')
    gate_silu = torch.nn.functional.silu(gate)
    swiglu_result = gate_silu * up

    # interleave_l1_weights is given a leading batch dimension of 1, which is
    # stripped again afterwards.
    # NOTE(review): the exact interleave granularity (the original comment
    # hints at groups of 8) lives in interleave_l1_weights -- confirm there.
    batched = torch.cat([gate_silu, swiglu_result], dim=-1).unsqueeze(0)
    fused = interleave_l1_weights(batched)[0]

    global_scale = 1.0
    # Reference: quantize the swiglu half on its own.
    ref_fp4, ref_sf = quantize_activation_nvfp4(swiglu_result, global_scale)
    # Kernel under test: expected to pick the swiglu half out of the
    # interleaved tensor and quantize it in one pass.
    gpu_fp4, gpu_sf = deinterleave_quantize_nvfp4_cuda(fused, intermediate, global_scale)

    fp4_match = byte_agreement(ref_fp4, gpu_fp4)
    sf_match = byte_agreement(ref_sf, gpu_sf)
    print(f" FP4 byte match: {fp4_match*100:.1f}%")
    print(f" SF byte match: {sf_match*100:.1f}%")

    assert fp4_match >= 0.99, f"FP4 byte match too low: {fp4_match}"
    assert sf_match >= 0.99, f"SF byte match too low: {sf_match}"
    print(" ✅ PASS")
def test():
    """Run every test in this file in order and print a closing banner."""
    print("=== NVFP4 GPU Quantize Kernel Tests ===")
    cases = (
        test_quantize_nvfp4_gpu_basic,
        test_quantize_nvfp4_gpu_larger,
        test_quantize_nvfp4_gpu_no_cpu_sync,
        test_deinterleave_quantize_correctness,
    )
    # A failing assertion in any case propagates and stops the run.
    for case in cases:
        case()
    print("\n=== ALL TESTS PASSED ===")
def _dequantize_nvfp4(x_fp4, block_scale, global_scale, N):
    """Dequantize packed NVFP4 data back to BF16 for verification.

    Args:
        x_fp4: Packed FP4 tensor with M rows. It is reinterpreted as uint8;
            each byte holds two 4-bit codes, low nibble first.
        block_scale: Per-block scale factors. Each entry is repeated
            SF_VEC_SIZE times along the last dimension, so that dimension is
            presumably N // SF_VEC_SIZE -- confirm against the quantizer.
        global_scale: Scalar multiplied into every dequantized value.
        N: Number of unpacked elements per row.

    Returns:
        An (M, N) bfloat16 tensor on the same device as ``x_fp4``.
    """
    M = x_fp4.shape[0]
    raw = x_fp4.view(torch.uint8)

    # Unpack two codes per byte and restore element order (low nibble first).
    low_nibble = raw & 0x0F
    high_nibble = (raw >> 4) & 0x0F
    codes = torch.stack([low_nibble, high_nibble], dim=-1).reshape(M, N)

    # Bit 3 is the sign; bits 0-2 index the magnitude table.
    signs = (codes >= 8).float() * -2 + 1
    magnitude_idx = (codes % 8).long()

    # E2M1 magnitudes are not evenly spaced: the step is 0.5 up to 2.0, then
    # 1.0, then 2.0. A linear table would decode 3/4/6 as 2.5/3/3.5.
    # The table follows the input's device instead of assuming CUDA.
    e2m1_magnitudes = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
        dtype=torch.float32,
        device=raw.device,
    )
    x_deq_fp32 = signs * e2m1_magnitudes[magnitude_idx]

    # Broadcast each block's scale over its SF_VEC_SIZE elements.
    block_scale_exp = block_scale.repeat_interleave(SF_VEC_SIZE, dim=-1).float()
    x_deq_fp32 = x_deq_fp32 * block_scale_exp * global_scale
    return x_deq_fp32.to(torch.bfloat16)
# Script entry point: run the whole suite when executed directly.
if __name__ == "__main__":
    test()