"""NVFP4-0: Verify Blackwell FP4 primitives are correct.

Quick diagnostic — prints sf_dtype, sf_vec_size, MMA kind, TMA element type.
Run on B200 only.
"""
|
|
import torch
|
|
import cutlass
|
|
import cutlass.cute as cute
|
|
from cutlass.cute.nvgpu import tcgen05
|
|
from cutlass import Float32, BFloat16, Float8E4M3FN, Float8E8M0FNU, Float4E2M1FN
|
|
import sys, os
|
|
# Put the directory three levels above this file at the front of sys.path so
# the project-local imports performed later in this file can be resolved.
sys.path.insert(
    0,
    os.path.abspath(
        os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            os.pardir,
            os.pardir,
        )
    ),
)
|
|
|
|
|
|
def _print_section(title):
    """Print *title* framed above and below by a 60-character ``=`` rule."""
    rule = "=" * 60
    print(rule)
    print(title)
    print(rule)


def _report_kernel(label, failure_label, factory, attr_names):
    """Build an object via *factory* and print each attribute in *attr_names*.

    Construction and attribute reads share a single ``try`` so that an
    exception from either one is reported as one ``failure_label`` line and
    the surrounding diagnostic keeps running.
    """
    try:
        kernel = factory()
        for attr_name in attr_names:
            print(f" {label}.{attr_name} = {getattr(kernel, attr_name)}")
    except Exception as e:  # deliberate best-effort: report and continue
        print(f" {failure_label}: {e}")


def _random_bytes_viewed_as(rows, cols, dtype):
    """Return a (rows, cols) CUDA tensor of random bytes reinterpreted as *dtype*."""
    raw = torch.randint(0, 256, (rows, cols), device='cuda', dtype=torch.uint8)
    return raw.view(dtype)


def test_nvfp4_primitives():
    """Print a diagnostic report on the NVFP4 building blocks.

    The report covers, in order:

    * NVFP4-0.1 -- ``SF_VEC_SIZE`` from ``dsv4.ops.quantize`` (asserted to be
      16), attributes of the dense, grouped and fused-SwiGLU GEMM kernel
      objects, the dtypes of hand-built FP4 / scale-factor tensors, the
      output dtypes of ``quantize_activation_nvfp4`` and the runner-like
      names exported by ``dsv4.ops.gemm_runner``.
    * NVFP4-0.3 -- dtype and shape of the quantized FP4 tensor.
    * NVFP4-0.4 -- construction of an FP4 tiled MMA and the
      ``use_2cta_instrs`` flag for two tiler / cluster configurations.

    Kernel and MMA construction failures are printed and do not stop the
    run. Dtype mismatches are reported with a marker but are not raised.
    Tensors are allocated with ``device='cuda'``, so a CUDA device is
    required.

    Raises:
        AssertionError: If ``SF_VEC_SIZE`` is not 16.
    """
    _print_section("NVFP4-0.1: sf_dtype and sf_vec_size in GEMM kernels")

    # Project imports stay inside the function so the section header is
    # printed before any import error surfaces.
    from dsv4.ops.quantize import SF_VEC_SIZE
    print(f" quantize.py SF_VEC_SIZE = {SF_VEC_SIZE}")
    assert SF_VEC_SIZE == 16, f"expected SF_VEC_SIZE == 16 for NVFP4, got {SF_VEC_SIZE}"
    print(" ✅ SF_VEC_SIZE = 16 (NVFP4 correct)")

    # Construct dense GEMM with the real API.
    # NOTE: sf_dtype is not printed here; per the original author's note it is
    # set in _setup_attributes from the actual tensor pointers.
    from dsv4.kernels.gemm.dense import Sm100BlockScaledPersistentDenseGemmKernel
    _report_kernel(
        "DenseGEMM",
        "DenseGEMM construction failed",
        lambda: Sm100BlockScaledPersistentDenseGemmKernel(
            sf_vec_size=16,
            mma_tiler_mn=(128, 256),
            cluster_shape_mn=(1, 1),
        ),
        ("sf_vec_size", "cta_group", "acc_dtype", "use_2cta_instrs"),
    )

    # Try sf_vec_size=32 (MXFP4) for comparison.
    _report_kernel(
        "DenseGEMM (MXF4/vs=32)",
        "DenseGEMM (MXF4) failed",
        lambda: Sm100BlockScaledPersistentDenseGemmKernel(
            sf_vec_size=32,
            mma_tiler_mn=(128, 256),
            cluster_shape_mn=(1, 1),
        ),
        ("sf_vec_size",),
    )

    # Check grouped kernel.
    from dsv4.kernels.gemm.grouped import ScaledGroupedGemmKernel
    _report_kernel(
        "GroupedGEMM",
        "GroupedGEMM construction failed",
        lambda: ScaledGroupedGemmKernel(
            sf_vec_size=16,
            mma_tiler_mn=(128, 256),
            cluster_shape_mn=(1, 1),
        ),
        ("sf_vec_size",),
    )

    # Check fused SwiGLU kernel.
    from dsv4.kernels.gemm.fused_swiglu import FusedSwiGLUScaledGroupedGemmKernel
    _report_kernel(
        "FusedSwiGLU",
        "FusedSwiGLU construction failed",
        lambda: FusedSwiGLUScaledGroupedGemmKernel(
            sf_vec_size=16,
            mma_tiler_mn=(128, 256),
            cluster_shape_mn=(1, 1),
            swiglu_limit=10.0,
        ),
        ("sf_vec_size", "swiglu_limit"),
    )

    # Build sample tensors and show their dtypes. Only the tensors whose
    # dtype/shape is actually reported are allocated.
    print()
    print(" --- sf_dtype inference from tensor dtypes ---")
    M, K = 128, 512
    a_fp4 = _random_bytes_viewed_as(M, K // 2, torch.float4_e2m1fn_x2)
    a_sf_e4m3 = _random_bytes_viewed_as(M, K // 16, torch.float8_e4m3fn)
    a_sf_e8m0 = _random_bytes_viewed_as(M, K // 32, torch.float8_e8m0fnu)

    print(f" a_fp4 dtype = {a_fp4.dtype} (expect float4_e2m1fn_x2)")
    print(f" a_sf (E4M3) dtype = {a_sf_e4m3.dtype}, shape = {a_sf_e4m3.shape}")
    print(f" a_sf (E8M0) dtype = {a_sf_e8m0.dtype}, shape = {a_sf_e8m0.shape}")

    # Does the runner use E4M3 or E8M0? Check what quantize actually produces.
    from dsv4.ops.quantize import quantize_activation_nvfp4
    x = torch.randn(4, 512, device='cuda', dtype=torch.bfloat16)
    global_scale = torch.tensor(448.0, device='cuda', dtype=torch.float32)
    x_fp4, x_sf = quantize_activation_nvfp4(x, global_scale)
    print(f" quantize_activation_nvfp4 output: FP4 dtype={x_fp4.dtype}, SF dtype={x_sf.dtype}")
    if x_sf.dtype == torch.float8_e4m3fn:
        print(" ✅ Scale factors are E4M3 — NVFP4 correct")
    elif x_sf.dtype == torch.float8_e8m0fnu:
        print(" ⚠️ Scale factors are E8M0 — this is MXFP4 format, NOT NVFP4!")
    else:
        print(f" ❌ Scale factors are {x_sf.dtype} — unexpected!")

    # List the names in gemm_runner that look like runner / GEMM classes.
    import dsv4.ops.gemm_runner as gr
    runner_classes = [name for name in dir(gr) if 'unner' in name or 'Gemm' in name]
    print(f" gemm_runner classes: {runner_classes}")

    print()
    _print_section("NVFP4-0.3: FP4 TMA element type")

    print(f" FP4 tensor dtype = {x_fp4.dtype}")
    print(f" FP4 tensor shape = {x_fp4.shape}")
    print(f" torch.float4_e2m1fn_x2 available = {hasattr(torch, 'float4_e2m1fn_x2')}")
    if x_fp4.dtype == torch.float4_e2m1fn_x2:
        print(" ✅ FP4 tensor is float4_e2m1fn_x2 — correct for TMA")
    else:
        print(f" ❌ FP4 tensor dtype is {x_fp4.dtype} — should be float4_e2m1fn_x2!")

    print()
    _print_section("NVFP4-0.4: MMA kind verification")

    # Test with correct LayoutEnum.
    from cutlass.utils import LayoutEnum
    try:
        a_major = LayoutEnum.ROW_MAJOR.mma_major_mode()
        b_major = LayoutEnum.COLUMN_MAJOR.mma_major_mode()
        mma = cutlass.utils.sm100.make_trivial_tiled_mma(
            Float4E2M1FN, Float4E2M1FN, a_major, b_major, Float32,
            tcgen05.CtaGroup.ONE, (128, 256),
            tcgen05.OperandSource.SMEM,
        )
        print(f" FP4 MMA shape_mnk = {mma.shape_mnk}")
        print(" ✅ FP4 MMA construction succeeded")
    except Exception as e:  # deliberate best-effort: report and continue
        print(f" FP4 MMA construction failed: {e}")

    # Check use_2cta_instrs behavior for a 1-CTA and a 2-CTA configuration.
    print()
    print(" --- use_2cta_instrs check ---")
    for failure_label, tiler_mn, cluster_mn in (
        ("1-CTA failed", (128, 256), (1, 1)),
        ("2-CTA failed", (256, 256), (2, 1)),
    ):
        try:
            kernel = Sm100BlockScaledPersistentDenseGemmKernel(
                sf_vec_size=16, mma_tiler_mn=tiler_mn, cluster_shape_mn=cluster_mn,
            )
            print(
                f" mma_tiler_mn=({tiler_mn[0]},{tiler_mn[1]})"
                f" → use_2cta_instrs={kernel.use_2cta_instrs}"
            )
        except Exception as e:  # deliberate best-effort: report and continue
            print(f" {failure_label}: {e}")

    print()
    print("DONE — NVFP4-0 verification complete")
|
|
|
|
|
|
# Run the diagnostic only when this file is executed directly as a script.
if __name__ == "__main__":
    test_nvfp4_primitives()
|