diff --git a/tests/unit/test_fmha_tma.cu b/tests/unit/test_fmha_tma.cu index 43a4b3d8..6c05340b 100644 --- a/tests/unit/test_fmha_tma.cu +++ b/tests/unit/test_fmha_tma.cu @@ -159,8 +159,7 @@ struct TmaDescSet { // The data in GMEM starts at d_q, shape (T, HD), stride (HD, 1). // We treat it as (128, HD) — rows beyond T are garbage, kernel ignores them. uint32_t q_tile_rows = 128; - CUresult q_res = create_tma_desc_2d_bf16(&tma_q, d_q, 128, (uint64_t)hd, q_tile_rows, (uint32_t)hd); - if (q_res != true) { + if (!create_tma_desc_2d_bf16(&tma_q, d_q, 128, (uint64_t)hd, q_tile_rows, (uint32_t*)hd)) { printf(" Failed to create Q TMA desc: rows=128, cols=%d, tile_rows=128, tile_cols=%d\n", hd, hd); return false; }