131 lines
4.7 KiB
Python
131 lines
4.7 KiB
Python
"""
|
|
P4: Dump TMA descriptor bytes from both CuTeDSL and cuTensorMapEncodeTiled.
|
|
|
|
1. CuTeDSL: create a TMA descriptor for a (128,16) BF16 tensor via cute.compile
|
|
and dump the 128 bytes.
|
|
2. Driver API: use cuTensorMapEncodeTiled for the same tensor and dump 128 bytes.
|
|
3. memcmp and print differences.
|
|
|
|
The CuTeDSL path already works (it's used in the existing FMHA kernel).
|
|
The raw Driver API path hangs when used with cp.async.bulk.tensor.2d.
|
|
By comparing descriptors byte-by-byte, we can identify the field that differs.
|
|
"""
|
|
import torch
|
|
import sys
|
|
import os
|
|
import struct
|
|
import numpy as np
|
|
|
|
# Put the parent of this script's directory at the front of the import path.
_script_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(_script_dir))
|
|
|
|
|
|
def dump_driver_api_descriptor():
    """Encode a TMA descriptor with the CUDA Driver API and return its bytes.

    Allocates a zero-filled (128, 16) BF16 tensor on the GPU and describes it
    to ``cuTensorMapEncodeTiled`` as a rank-2 tensor with a 16x16 box, unit
    element strides, and no interleave, swizzle, L2 promotion or OOB fill.

    Returns:
        The 128 bytes of the encoded ``CUtensorMap``, or ``None`` (after
        printing the error code) when ``cuTensorMapEncodeTiled`` does not
        report ``CUDA_SUCCESS``.
    """
    import ctypes

    from cuda.bindings import driver, runtime

    descriptor_size = 128  # a CUtensorMap is a 128-byte opaque struct

    # Force context creation before calling into the driver API.
    runtime.cudaFree(0)

    # Create a (128, 16) BF16 tensor on GPU (row-major, contiguous).
    rows, cols = 128, 16
    box_rows, box_cols = 16, 16
    data = torch.zeros(rows, cols, dtype=torch.bfloat16, device='cuda')
    row_stride_bytes = cols * data.element_size()  # 16 * 2 = 32 bytes

    # The Driver API orders dimensions from the fastest-varying (innermost)
    # one outwards, so the column extent comes first.  The cuda-python binding
    # requires the list elements to be wrapped in its own integer types.
    global_dims = [driver.cuuint64_t(cols), driver.cuuint64_t(rows)]
    # Only rank-1 strides are passed: the innermost dimension is implicitly
    # contiguous, so the single entry is the byte stride between rows.
    global_strides = [driver.cuuint64_t(row_stride_bytes)]
    box_dims = [driver.cuuint32_t(box_cols), driver.cuuint32_t(box_rows)]
    element_strides = [driver.cuuint32_t(1), driver.cuuint32_t(1)]

    # cuda-python returns the encoded descriptor instead of filling an output
    # argument: the call yields a (CUresult, CUtensorMap) pair.
    status, tensor_map = driver.cuTensorMapEncodeTiled(
        driver.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
        2,  # tensorRank
        int(data.data_ptr()),
        global_dims,
        global_strides,
        box_dims,
        element_strides,
        driver.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
        driver.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE,
        driver.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_NONE
        if hasattr(driver, "CUtensorMapL2promotion")
        else driver.CUtensorMapL2Promotion.CU_TENSOR_MAP_L2_PROMOTION_NONE,
        driver.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
        if hasattr(driver, "CUtensorMapFloatOOBfill")
        else driver.CUtensorMapOOBFill.CU_TENSOR_MAP_OOB_FILL_NONE,
    )

    if status != driver.CUresult.CUDA_SUCCESS:
        print(f"cuTensorMapEncodeTiled failed: {status}")
        return None

    # Copy the opaque struct out of host memory.  getPtr() is the address of
    # the wrapped struct per the cuda-python docs -- NOTE(review): confirm
    # against the installed cuda-python version.
    return ctypes.string_at(tensor_map.getPtr(), descriptor_size)
|
|
|
|
|
|
def dump_cutedsl_descriptor():
    """Placeholder for dumping the TMA descriptor that CuTeDSL builds.

    Not implemented yet: the function does nothing and returns ``None``.

    Author's notes on why this is hard (unverified here):
    - CuTeDSL creates descriptors internally when ``cute.make_tensor`` is
      called with a TMA layout, so the bytes would have to be intercepted.
    - The descriptors are created at JIT compile time and stored in the
      kernel's parameter struct, which is not easily dumped from Python.

    Candidate approaches:
    - Use CuTe's TMA descriptor creation API directly
      (``cute.arch.make_tma_copy``) and inspect the result.
    - Write a small CuTeDSL kernel that does a (working) TMA load and capture
      the descriptor with Nsight.
    - Use the CUTLASS Python bindings that CuTeDSL uses internally, where the
      descriptor is a Python object before it is passed to the kernel.
    """
    return None
|
|
|
|
|
|
def main():
    """Print a hex dump of the Driver API TMA descriptor.

    Calls ``dump_driver_api_descriptor()`` and prints the first 128 bytes of
    its result as eight rows of 16 hex bytes, or a failure line when it
    returns ``None``.  Always ends with a note pointing at the CUDA test that
    performs the full comparison.
    """
    print("P4: TMA Descriptor Comparison")
    print("=" * 60)

    # Step 1: Driver API descriptor
    print("\n1. Driver API (cuTensorMapEncodeTiled) descriptor:")
    descriptor = dump_driver_api_descriptor()
    if descriptor is None:
        print(" FAILED to create descriptor")
    else:
        for start in range(0, 128, 16):
            row = descriptor[start:start + 16]
            hex_row = ' '.join(format(byte, '02x') for byte in row)
            print(f" [{start:3d}-{start + 15:3d}]: {hex_row}")

    print("\nNote: CuTeDSL descriptor dump requires running inside a JIT kernel.")
    print("Use the CUDA test (test_p4_tma_descriptor_diff.cu) for the full comparison.")
|
|
|
|
|
|
# Script entry point: run the dump only when executed directly, not on import.
if __name__ == "__main__":
    main()
|