diag: NVFP4-0 primitive verification script
This commit is contained in:
112
tests/unit/test_nvfp4_primitives.py
Normal file
112
tests/unit/test_nvfp4_primitives.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""NVFP4-0: Verify Blackwell FP4 primitives are correct.
|
||||
|
||||
Quick diagnostic — prints sf_dtype, sf_vec_size, MMA kind, TMA element type.
|
||||
Run on B200 only.
|
||||
"""
|
||||
import torch
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.nvgpu import tcgen05
|
||||
from cutlass import Float32, BFloat16, Float8E4M3FN, Float8E8M0FNU, Float4E2M1FN
|
||||
|
||||
def test_nvfp4_primitives():
|
||||
print("=" * 60)
|
||||
print("NVFP4-0.1: sf_dtype and sf_vec_size in GEMM runner")
|
||||
print("=" * 60)
|
||||
|
||||
# Check what gemm_runner uses
|
||||
import sys, os
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
from dsv4.ops.gemm_runner import SF_VEC_SIZE
|
||||
print(f" gemm_runner SF_VEC_SIZE = {SF_VEC_SIZE}")
|
||||
|
||||
# Check what dense.py BlockScaledGEMM expects
|
||||
from dsv4.kernels.gemm.dense import BlockScaledGEMM
|
||||
# Print the docstring-specified types
|
||||
print(f" dense.py NVF4 spec: A/B=Float4E2M1FN, SF=Float8E8M0FNU/Float8E4M3FN, sf_vec_size=16")
|
||||
print(f" dense.py MXF4 spec: A/B=Float4E2M1FN, SF=Float8E8M0FNU, sf_vec_size=32")
|
||||
|
||||
# Construct a small NVFP4 GEMM and print the sf_dtype it receives
|
||||
M, N, K = 128, 256, 512
|
||||
# Create small FP4 tensors
|
||||
a_fp4 = torch.randn(M, K // 2, device='cuda', dtype=torch.float4_e2m1fn_x2)
|
||||
b_fp4 = torch.randn(K // 2, N, device='cuda', dtype=torch.float4_e2m1fn_x2)
|
||||
a_sf = torch.randn(M, K // 16, device='cuda', dtype=torch.float8_e4m3fn)
|
||||
b_sf = torch.randn(N, K // 16, device='cuda', dtype=torch.float8_e4m3fn)
|
||||
|
||||
try:
|
||||
from dsv4.kernels.gemm.dense import BlockScaledGEMM
|
||||
gemm = BlockScaledGEMM(
|
||||
a_ptr=a_fp4,
|
||||
b_ptr=b_fp4,
|
||||
sfa_ptr=a_sf,
|
||||
sfb_ptr=b_sf,
|
||||
sf_vec_size=16,
|
||||
)
|
||||
print(f" BlockScaledGEMM.sf_vec_size = {gemm.sf_vec_size}")
|
||||
print(f" BlockScaledGEMM.sf_dtype = {gemm.sf_dtype}")
|
||||
print(f" BlockScaledGEMM.mma_inst_shape_mn = {gemm.mma_inst_shape_mn}")
|
||||
except Exception as e:
|
||||
print(f" BlockScaledGEMM construction failed: {e}")
|
||||
|
||||
# Also try with E8M0 scale dtype to see if that's what's being used
|
||||
a_sf_e8m0 = torch.randn(M, K // 32, device='cuda', dtype=torch.float8_e8m0fnu)
|
||||
b_sf_e8m0 = torch.randn(N, K // 32, device='cuda', dtype=torch.float8_e8m0fnu)
|
||||
try:
|
||||
gemm_e8m0 = BlockScaledGEMM(
|
||||
a_ptr=a_fp4,
|
||||
b_ptr=b_fp4,
|
||||
sfa_ptr=a_sf_e8m0,
|
||||
sfb_ptr=b_sf_e8m0,
|
||||
sf_vec_size=32,
|
||||
)
|
||||
print(f" BlockScaledGEMM (E8M0/sf_vec_size=32).sf_dtype = {gemm_e8m0.sf_dtype}")
|
||||
print(f" BlockScaledGEMM (E8M0/sf_vec_size=32).sf_vec_size = {gemm_e8m0.sf_vec_size}")
|
||||
except Exception as e:
|
||||
print(f" BlockScaledGEMM (E8M0) construction failed: {e}")
|
||||
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("NVFP4-0.3: FP4 TMA element type in quantize.py")
|
||||
print("=" * 60)
|
||||
|
||||
from dsv4.ops.quantize import quantize_tensor_nvfp4, _create_fp4_tensor
|
||||
x = torch.randn(4, 512, device='cuda', dtype=torch.bfloat16)
|
||||
x_fp4, x_sf = quantize_tensor_nvfp4(x)
|
||||
print(f" Input dtype: {x.dtype}")
|
||||
print(f" FP4 output dtype: {x_fp4.dtype}")
|
||||
print(f" SF output dtype: {x_sf.dtype}")
|
||||
print(f" FP4 shape: {x_fp4.shape} (expected: [4, 256])")
|
||||
print(f" SF shape: {x_sf.shape} (expected: [4, 32])")
|
||||
|
||||
# Check that float4_e2m1fn_x2 survives into TMA
|
||||
print(f" torch.float4_e2m1fn_x2 = {torch.float4_e2m1fn_x2}")
|
||||
print(f" FP4 tensor is float4_e2m1fn_x2: {x_fp4.dtype == torch.float4_e2m1fn_x2}")
|
||||
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("NVFP4-0.4: MMA kind verification")
|
||||
print("=" * 60)
|
||||
|
||||
# Try to inspect the MMA kind that tcgen05 uses for NVFP4
|
||||
try:
|
||||
# Create a tiled MMA with FP4 inputs
|
||||
a_major = cutlass.utils.LayoutEnum.ROW_MAJOR.mma_major_mode()
|
||||
b_major = cutlass.utils.LayoutEnum.COLUMN_MAJOR.mma_major_mode()
|
||||
mma = cutlass.utils.sm100.make_trivial_tiled_mma(
|
||||
Float4E2M1FN, Float4E2M1FN, a_major, b_major, Float32,
|
||||
tcgen05.CtaGroup.ONE, (128, 256),
|
||||
tcgen05.OperandSource.SMEM,
|
||||
)
|
||||
print(f" FP4 MMA shape_mnk = {mma.shape_mnk}")
|
||||
print(f" FP4 MMA thr_id shape = {mma.thr_id.shape}")
|
||||
print(f" FP4 MMA operand A dtype = {mma.a_layout.value_type if hasattr(mma, 'a_layout') else 'N/A'}")
|
||||
except Exception as e:
|
||||
print(f" FP4 MMA construction failed: {e}")
|
||||
|
||||
print()
|
||||
print("DONE — NVFP4-0 verification complete")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_nvfp4_primitives()
|
||||
Reference in New Issue
Block a user