diff --git a/tests/unit/test_tma_konly.cu b/tests/unit/test_tma_konly.cu index 04a9735f..c77d48f8 100644 --- a/tests/unit/test_tma_konly.cu +++ b/tests/unit/test_tma_konly.cu @@ -74,7 +74,7 @@ fmha_tma_konly_kernel( // ===== QK GEMM ===== { uint32_t idesc = make_idesc(BLOCK_MN, BLOCK_MN); - for (int kt = 0; kt < NKT; kt++) { + for (int kt = 0; kt < 1; kt++) { // Only 1 K sub-tile for now // Load Q sub-tile: direct from GMEM (T=1, only row 0) for (int i = tid; i < TILE_SZ; i += 128) sQ0[i] = 0; for (int d = tid; d < MMA_K_BF16; d += 128) { @@ -158,10 +158,16 @@ int main() { cudaMemcpy(d_q, h_q, HD * sizeof(bf16_t), cudaMemcpyHostToDevice); cudaMemcpy(d_k, h_k, SK * HD * sizeof(bf16_t), cudaMemcpyHostToDevice); - // TMA descriptor for K + // Try simple TMA: (128, 16) descriptor for just the first K sub-tile + // If this works, the issue is with (128, 64) descriptors + bf16_t* d_k_sub; + cudaMalloc(&d_k_sub, SK * MMA_K_BF16 * sizeof(bf16_t)); + // Copy first sub-tile of K + cudaMemcpy(d_k_sub, d_k, SK * MMA_K_BF16 * sizeof(bf16_t), cudaMemcpyDeviceToDevice); + CUtensorMap tma_k; CUtensorMap* d_tma_k; - if (!create_tma_desc_2d_bf16(&tma_k, d_k, SK, HD, BLOCK_MN, MMA_K_BF16)) { + if (!create_tma_desc_2d_bf16(&tma_k, d_k_sub, SK, (uint64_t)MMA_K_BF16, BLOCK_MN, MMA_K_BF16)) { printf("TMA K desc FAILED\n"); return 1; } cudaMalloc(&d_tma_k, sizeof(CUtensorMap));