test: simple (128,16) TMA desc for K sub-tile only

This commit is contained in:
2026-05-29 18:45:01 +00:00
parent eaf8a878cf
commit d64b62bc80

View File

@@ -74,7 +74,7 @@ fmha_tma_konly_kernel(
// ===== QK GEMM =====
{
uint32_t idesc = make_idesc(BLOCK_MN, BLOCK_MN);
for (int kt = 0; kt < NKT; kt++) {
for (int kt = 0; kt < 1; kt++) { // Only 1 K sub-tile for now
// Load Q sub-tile: direct from GMEM (T=1, only row 0)
for (int i = tid; i < TILE_SZ; i += 128) sQ0[i] = 0;
for (int d = tid; d < MMA_K_BF16; d += 128) {
@@ -158,10 +158,16 @@ int main() {
cudaMemcpy(d_q, h_q, HD * sizeof(bf16_t), cudaMemcpyHostToDevice);
cudaMemcpy(d_k, h_k, SK * HD * sizeof(bf16_t), cudaMemcpyHostToDevice);
// TMA descriptor for K
// Try simple TMA: (128, 16) descriptor for just the first K sub-tile
// If this works, the issue is with (128, 64) descriptors
bf16_t* d_k_sub;
cudaMalloc(&d_k_sub, SK * MMA_K_BF16 * sizeof(bf16_t));
// Copy first sub-tile of K
cudaMemcpy(d_k_sub, d_k, SK * MMA_K_BF16 * sizeof(bf16_t), cudaMemcpyDeviceToDevice);
CUtensorMap tma_k;
CUtensorMap* d_tma_k;
if (!create_tma_desc_2d_bf16(&tma_k, d_k, SK, HD, BLOCK_MN, MMA_K_BF16)) {
if (!create_tma_desc_2d_bf16(&tma_k, d_k_sub, SK, (uint64_t)MMA_K_BF16, BLOCK_MN, MMA_K_BF16)) {
printf("TMA K desc FAILED\n"); return 1;
}
cudaMalloc(&d_tma_k, sizeof(CUtensorMap));