diag: NVFP4-0 primitive verification script

This commit is contained in:
2026-05-23 08:26:56 +00:00
parent 241b49b1ee
commit f6f4ce8ec2

View File

@@ -0,0 +1,112 @@
"""NVFP4-0: Verify Blackwell FP4 primitives are correct.
Quick diagnostic — prints sf_dtype, sf_vec_size, MMA kind, TMA element type.
Run on B200 only.
"""
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16, Float8E4M3FN, Float8E8M0FNU, Float4E2M1FN
def test_nvfp4_primitives():
print("=" * 60)
print("NVFP4-0.1: sf_dtype and sf_vec_size in GEMM runner")
print("=" * 60)
# Check what gemm_runner uses
import sys, os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from dsv4.ops.gemm_runner import SF_VEC_SIZE
print(f" gemm_runner SF_VEC_SIZE = {SF_VEC_SIZE}")
# Check what dense.py BlockScaledGEMM expects
from dsv4.kernels.gemm.dense import BlockScaledGEMM
# Print the docstring-specified types
print(f" dense.py NVF4 spec: A/B=Float4E2M1FN, SF=Float8E8M0FNU/Float8E4M3FN, sf_vec_size=16")
print(f" dense.py MXF4 spec: A/B=Float4E2M1FN, SF=Float8E8M0FNU, sf_vec_size=32")
# Construct a small NVFP4 GEMM and print the sf_dtype it receives
M, N, K = 128, 256, 512
# Create small FP4 tensors
a_fp4 = torch.randn(M, K // 2, device='cuda', dtype=torch.float4_e2m1fn_x2)
b_fp4 = torch.randn(K // 2, N, device='cuda', dtype=torch.float4_e2m1fn_x2)
a_sf = torch.randn(M, K // 16, device='cuda', dtype=torch.float8_e4m3fn)
b_sf = torch.randn(N, K // 16, device='cuda', dtype=torch.float8_e4m3fn)
try:
from dsv4.kernels.gemm.dense import BlockScaledGEMM
gemm = BlockScaledGEMM(
a_ptr=a_fp4,
b_ptr=b_fp4,
sfa_ptr=a_sf,
sfb_ptr=b_sf,
sf_vec_size=16,
)
print(f" BlockScaledGEMM.sf_vec_size = {gemm.sf_vec_size}")
print(f" BlockScaledGEMM.sf_dtype = {gemm.sf_dtype}")
print(f" BlockScaledGEMM.mma_inst_shape_mn = {gemm.mma_inst_shape_mn}")
except Exception as e:
print(f" BlockScaledGEMM construction failed: {e}")
# Also try with E8M0 scale dtype to see if that's what's being used
a_sf_e8m0 = torch.randn(M, K // 32, device='cuda', dtype=torch.float8_e8m0fnu)
b_sf_e8m0 = torch.randn(N, K // 32, device='cuda', dtype=torch.float8_e8m0fnu)
try:
gemm_e8m0 = BlockScaledGEMM(
a_ptr=a_fp4,
b_ptr=b_fp4,
sfa_ptr=a_sf_e8m0,
sfb_ptr=b_sf_e8m0,
sf_vec_size=32,
)
print(f" BlockScaledGEMM (E8M0/sf_vec_size=32).sf_dtype = {gemm_e8m0.sf_dtype}")
print(f" BlockScaledGEMM (E8M0/sf_vec_size=32).sf_vec_size = {gemm_e8m0.sf_vec_size}")
except Exception as e:
print(f" BlockScaledGEMM (E8M0) construction failed: {e}")
print()
print("=" * 60)
print("NVFP4-0.3: FP4 TMA element type in quantize.py")
print("=" * 60)
from dsv4.ops.quantize import quantize_tensor_nvfp4, _create_fp4_tensor
x = torch.randn(4, 512, device='cuda', dtype=torch.bfloat16)
x_fp4, x_sf = quantize_tensor_nvfp4(x)
print(f" Input dtype: {x.dtype}")
print(f" FP4 output dtype: {x_fp4.dtype}")
print(f" SF output dtype: {x_sf.dtype}")
print(f" FP4 shape: {x_fp4.shape} (expected: [4, 256])")
print(f" SF shape: {x_sf.shape} (expected: [4, 32])")
# Check that float4_e2m1fn_x2 survives into TMA
print(f" torch.float4_e2m1fn_x2 = {torch.float4_e2m1fn_x2}")
print(f" FP4 tensor is float4_e2m1fn_x2: {x_fp4.dtype == torch.float4_e2m1fn_x2}")
print()
print("=" * 60)
print("NVFP4-0.4: MMA kind verification")
print("=" * 60)
# Try to inspect the MMA kind that tcgen05 uses for NVFP4
try:
# Create a tiled MMA with FP4 inputs
a_major = cutlass.utils.LayoutEnum.ROW_MAJOR.mma_major_mode()
b_major = cutlass.utils.LayoutEnum.COLUMN_MAJOR.mma_major_mode()
mma = cutlass.utils.sm100.make_trivial_tiled_mma(
Float4E2M1FN, Float4E2M1FN, a_major, b_major, Float32,
tcgen05.CtaGroup.ONE, (128, 256),
tcgen05.OperandSource.SMEM,
)
print(f" FP4 MMA shape_mnk = {mma.shape_mnk}")
print(f" FP4 MMA thr_id shape = {mma.thr_id.shape}")
print(f" FP4 MMA operand A dtype = {mma.a_layout.value_type if hasattr(mma, 'a_layout') else 'N/A'}")
except Exception as e:
print(f" FP4 MMA construction failed: {e}")
print()
print("DONE — NVFP4-0 verification complete")
if __name__ == "__main__":
test_nvfp4_primitives()