P4: fix API signature rank/dtype order, OOB_FILL defines
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
/**
|
||||
* P4: Dump TMA descriptor bytes for comparison.
|
||||
* CUDA 13.2 compatible — uses correct API signature.
|
||||
*/
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
@@ -7,6 +8,33 @@
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
|
||||
/* CUDA 13.2 cuTensorMapEncodeTiled signature:
|
||||
* CUresult cuTensorMapEncodeTiled(
|
||||
* CUtensorMap *tensorMap,
|
||||
* cuuint32_t tensorRank,
|
||||
* CUtensorMapDataType dataType,
|
||||
* void *globalAddress,
|
||||
* const cuuint64_t *tensorDims,
|
||||
* const cuuint64_t *globalStrides,
|
||||
* const cuuint32_t *boxDims,
|
||||
* const cuuint32_t *elementStrides,
|
||||
* CUtensorMapInterleave interleave,
|
||||
* CUtensorMapSwizzle swizzle,
|
||||
* CUtensorMapL2promotion l2Promotion,
|
||||
* CUtensorMapOOBfill oobFill
|
||||
* );
|
||||
*
|
||||
* Note: OOB fill is CUtensorMapOOBfill (lowercase f) in CUDA 13.2
|
||||
*/
|
||||
|
||||
// Define missing enum values if needed
|
||||
#ifndef CU_TENSOR_MAP_OOB_FILL_NONE
|
||||
#define CU_TENSOR_MAP_OOB_FILL_NONE ((CUtensorMapOOBfill)0)
|
||||
#endif
|
||||
#ifndef CU_TENSOR_MAP_OOB_FILL_ZERO
|
||||
#define CU_TENSOR_MAP_OOB_FILL_ZERO ((CUtensorMapOOBfill)1)
|
||||
#endif
|
||||
|
||||
int main() {
|
||||
const int ROWS = 128;
|
||||
const int COLS = 16;
|
||||
@@ -24,77 +52,47 @@ int main() {
|
||||
CUtensorMap tma_desc;
|
||||
CUresult res;
|
||||
|
||||
// Config 1: NO swizzle
|
||||
printf("=== Descriptor 1: NO swizzle ===\n");
|
||||
res = cuTensorMapEncodeTiled(
|
||||
&tma_desc, 2, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
|
||||
auto dump_desc = [](const char* label, const CUtensorMap& desc) {
|
||||
printf("=== %s ===\n", label);
|
||||
auto* b = reinterpret_cast<const uint8_t*>(&desc);
|
||||
for (int i = 0; i < 128; i += 16) {
|
||||
printf("[%3d-%3d]: ", i, i+15);
|
||||
for (int j = 0; j < 16; j++) printf("%02x ", b[i+j]);
|
||||
printf("\n");
|
||||
}
|
||||
};
|
||||
|
||||
// 1: NO swizzle
|
||||
res = cuTensorMapEncodeTiled(&tma_desc, 2, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
|
||||
d_ptr, tensorDims, globalStrides, boxDims, elementStrides,
|
||||
CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE,
|
||||
CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_OOB_FILL_NONE
|
||||
);
|
||||
if (res != CUDA_SUCCESS) { printf("FAILED: %d\n", res); }
|
||||
else {
|
||||
auto* b = reinterpret_cast<const uint8_t*>(&tma_desc);
|
||||
for (int i = 0; i < 128; i += 16) {
|
||||
printf("[%3d-%3d]: ", i, i+15);
|
||||
for (int j = 0; j < 16; j++) printf("%02x ", b[i+j]);
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_OOB_FILL_NONE);
|
||||
if (res == CUDA_SUCCESS) dump_desc("NO swizzle", tma_desc);
|
||||
else printf("=== NO swizzle: FAILED (%d) ===\n", res);
|
||||
|
||||
// Config 2: SWIZZLE_128B
|
||||
printf("\n=== Descriptor 2: SWIZZLE_128B ===\n");
|
||||
res = cuTensorMapEncodeTiled(
|
||||
&tma_desc, 2, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
|
||||
// 2: SWIZZLE_128B
|
||||
res = cuTensorMapEncodeTiled(&tma_desc, 2, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
|
||||
d_ptr, tensorDims, globalStrides, boxDims, elementStrides,
|
||||
CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_128B,
|
||||
CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_OOB_FILL_NONE
|
||||
);
|
||||
if (res != CUDA_SUCCESS) { printf("FAILED: %d\n", res); }
|
||||
else {
|
||||
auto* b = reinterpret_cast<const uint8_t*>(&tma_desc);
|
||||
for (int i = 0; i < 128; i += 16) {
|
||||
printf("[%3d-%3d]: ", i, i+15);
|
||||
for (int j = 0; j < 16; j++) printf("%02x ", b[i+j]);
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_OOB_FILL_NONE);
|
||||
if (res == CUDA_SUCCESS) dump_desc("SWIZZLE_128B", tma_desc);
|
||||
else printf("=== SWIZZLE_128B: FAILED (%d) ===\n", res);
|
||||
|
||||
// Config 3: NO swizzle, OOB_FILL_ZERO
|
||||
printf("\n=== Descriptor 3: NO swizzle, OOB_FILL_ZERO ===\n");
|
||||
res = cuTensorMapEncodeTiled(
|
||||
&tma_desc, 2, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
|
||||
// 3: NO swizzle, OOB_FILL_ZERO
|
||||
res = cuTensorMapEncodeTiled(&tma_desc, 2, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
|
||||
d_ptr, tensorDims, globalStrides, boxDims, elementStrides,
|
||||
CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE,
|
||||
CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_OOB_FILL_ZERO
|
||||
);
|
||||
if (res != CUDA_SUCCESS) { printf("FAILED: %d\n", res); }
|
||||
else {
|
||||
auto* b = reinterpret_cast<const uint8_t*>(&tma_desc);
|
||||
for (int i = 0; i < 128; i += 16) {
|
||||
printf("[%3d-%3d]: ", i, i+15);
|
||||
for (int j = 0; j < 16; j++) printf("%02x ", b[i+j]);
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_OOB_FILL_ZERO);
|
||||
if (res == CUDA_SUCCESS) dump_desc("NO swizzle + OOB_FILL_ZERO", tma_desc);
|
||||
else printf("=== NO swizzle + OOB_FILL_ZERO: FAILED (%d) ===\n", res);
|
||||
|
||||
// Config 4: SWIZZLE_128B, OOB_FILL_ZERO
|
||||
printf("\n=== Descriptor 4: SWIZZLE_128B, OOB_FILL_ZERO ===\n");
|
||||
res = cuTensorMapEncodeTiled(
|
||||
&tma_desc, 2, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
|
||||
// 4: SWIZZLE_128B, OOB_FILL_ZERO
|
||||
res = cuTensorMapEncodeTiled(&tma_desc, 2, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
|
||||
d_ptr, tensorDims, globalStrides, boxDims, elementStrides,
|
||||
CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_128B,
|
||||
CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_OOB_FILL_ZERO
|
||||
);
|
||||
if (res != CUDA_SUCCESS) { printf("FAILED: %d\n", res); }
|
||||
else {
|
||||
auto* b = reinterpret_cast<const uint8_t*>(&tma_desc);
|
||||
for (int i = 0; i < 128; i += 16) {
|
||||
printf("[%3d-%3d]: ", i, i+15);
|
||||
for (int j = 0; j < 16; j++) printf("%02x ", b[i+j]);
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_OOB_FILL_ZERO);
|
||||
if (res == CUDA_SUCCESS) dump_desc("SWIZZLE_128B + OOB_FILL_ZERO", tma_desc);
|
||||
else printf("=== SWIZZLE_128B + OOB_FILL_ZERO: FAILED (%d) ===\n", res);
|
||||
|
||||
cudaFree(d_ptr);
|
||||
printf("\nPASSED\n");
|
||||
|
||||
Reference in New Issue
Block a user