Files
nvfp4-megamoe-kernel/tests/unit/test_p4_tma_descriptor_diff.py

131 lines
4.7 KiB
Python

"""
P4: Dump TMA descriptor bytes from both CuTeDSL and cuTensorMapEncodeTiled.

Plan:
  1. CuTeDSL: create a TMA descriptor for a (128, 16) BF16 tensor via
     cute.compile and dump the 128 bytes.
  2. Driver API: use cuTensorMapEncodeTiled for the same tensor and dump
     128 bytes.
  3. memcmp the two descriptors and print the differences.

The CuTeDSL path already works (it is used in the existing FMHA kernel).
The raw Driver API path hangs when used with cp.async.bulk.tensor.2d.
By comparing the descriptors byte by byte, we can identify the field that
differs.

Status: only step 2 is implemented in this script. The CuTeDSL dump is still
a stub, and main() points to the CUDA test (test_p4_tma_descriptor_diff.cu)
for the full comparison.
"""
import torch
import sys
import os
import struct
import numpy as np
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
def dump_driver_api_descriptor():
    """Encode a TMA descriptor with the CUDA Driver API and return its bytes.

    Allocates a row-major (128, 16) BF16 tensor on the GPU and describes it
    with cuTensorMapEncodeTiled using a 16x16 box, unit element strides, and
    no interleave, swizzle, L2 promotion or out-of-bounds fill.

    Returns:
        The 128 raw bytes of the encoded CUtensorMap, or None when the driver
        reports an error (the error code is printed).
    """
    import ctypes

    from cuda.bindings import driver, runtime

    # Force context creation before any Driver API call.
    runtime.cudaFree(0)

    rows, cols = 128, 16
    elem_bytes = 2   # sizeof(bfloat16)
    desc_size = 128  # sizeof(CUtensorMap): 16 opaque 64-bit words
    data = torch.zeros(rows, cols, dtype=torch.bfloat16, device='cuda')

    # NOTE(review): the cuda-python binding is expected to require the
    # array arguments as lists of cuuint64_t / cuuint32_t wrappers rather
    # than plain ints -- confirm against the installed cuda-python version.
    u64 = driver.cuuint64_t
    u32 = driver.cuuint32_t

    # The Driver API orders dimensions innermost (contiguous) first, so a
    # row-major (rows, cols) tensor is described as [cols, rows].
    global_dims = [u64(cols), u64(rows)]
    # Only tensorRank - 1 strides are passed: the byte pitch of every
    # dimension except the innermost one, whose stride is implicitly the
    # element size. Each stride must be a multiple of 16 bytes (32 here).
    global_strides = [u64(cols * elem_bytes)]
    box_dims = [u32(16), u32(16)]  # TMA tile size, innermost first
    element_strides = [u32(1), u32(1)]

    # The Python binding has no output argument: it returns the status code
    # together with the freshly encoded CUtensorMap.
    result, tensor_map = driver.cuTensorMapEncodeTiled(
        driver.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
        2,  # tensorRank
        int(data.data_ptr()),
        global_dims,
        global_strides,
        box_dims,
        element_strides,
        driver.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
        driver.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE,
        driver.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_NONE,
        driver.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
    )
    if result != driver.CUresult.CUDA_SUCCESS:
        print(f"cuTensorMapEncodeTiled failed: {result}")
        return None

    # CUtensorMap is an opaque struct; copy its bytes straight out of the
    # host memory backing the binding object.
    return ctypes.string_at(tensor_map.getPtr(), desc_size)
def dump_cutedsl_descriptor():
    """Placeholder for dumping a CuTeDSL-built TMA descriptor.

    Not implemented: the function performs no work and returns None.

    Design notes carried over from the original author (unverified here):
      * CuTeDSL creates descriptors internally when cute.make_tensor is
        called with a TMA layout, so the bytes would have to be intercepted.
      * The descriptors are reportedly created at JIT compile time and stored
        in the kernel's parameter struct, which makes them hard to dump from
        Python.
      * Candidate approaches:
          - use CuTe's descriptor creation API directly
            (cute.arch.make_tma_copy) and inspect the result;
          - build a small CuTeDSL kernel that performs a working TMA load and
            capture the descriptor with Nsight;
          - use the CUTLASS Python bindings that CuTeDSL uses internally,
            where the descriptor is a Python object before it is passed to
            the kernel.
    """
    return None
def main():
    """Print a hex dump of the Driver API TMA descriptor.

    The descriptor is shown as eight rows of 16 bytes, each prefixed with its
    byte range. If the descriptor could not be created, a failure line is
    printed instead. A closing note points to the CUDA test for the CuTeDSL
    side of the comparison.
    """
    separator = "=" * 60
    print("P4: TMA Descriptor Comparison")
    print(separator)

    # Step 1: Driver API descriptor
    print("\n1. Driver API (cuTensorMapEncodeTiled) descriptor:")
    descriptor = dump_driver_api_descriptor()
    if descriptor is None:
        print(" FAILED to create descriptor")
    else:
        # 128 descriptor bytes, 16 per output row.
        for start in range(0, 128, 16):
            row = descriptor[start:start + 16]
            hex_row = ' '.join(f'{value:02x}' for value in row)
            print(f" [{start:3d}-{start+15:3d}]: {hex_row}")

    print("\nNote: CuTeDSL descriptor dump requires running inside a JIT kernel.")
    print("Use the CUDA test (test_p4_tma_descriptor_diff.cu) for the full comparison.")


if __name__ == "__main__":
    main()