test: simple (128,16) TMA desc for K sub-tile only
This commit is contained in:
@@ -74,7 +74,7 @@ fmha_tma_konly_kernel(
|
||||
// ===== QK GEMM =====
|
||||
{
|
||||
uint32_t idesc = make_idesc(BLOCK_MN, BLOCK_MN);
|
||||
for (int kt = 0; kt < NKT; kt++) {
|
||||
for (int kt = 0; kt < 1; kt++) { // Only 1 K sub-tile for now
|
||||
// Load Q sub-tile: direct from GMEM (T=1, only row 0)
|
||||
for (int i = tid; i < TILE_SZ; i += 128) sQ0[i] = 0;
|
||||
for (int d = tid; d < MMA_K_BF16; d += 128) {
|
||||
@@ -158,10 +158,16 @@ int main() {
|
||||
cudaMemcpy(d_q, h_q, HD * sizeof(bf16_t), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_k, h_k, SK * HD * sizeof(bf16_t), cudaMemcpyHostToDevice);
|
||||
|
||||
// TMA descriptor for K
|
||||
// Try simple TMA: (128, 16) descriptor for just the first K sub-tile
|
||||
// If this works, the issue is with (128, 64) descriptors
|
||||
bf16_t* d_k_sub;
|
||||
cudaMalloc(&d_k_sub, SK * MMA_K_BF16 * sizeof(bf16_t));
|
||||
// Copy first sub-tile of K
|
||||
cudaMemcpy(d_k_sub, d_k, SK * MMA_K_BF16 * sizeof(bf16_t), cudaMemcpyDeviceToDevice);
|
||||
|
||||
CUtensorMap tma_k;
|
||||
CUtensorMap* d_tma_k;
|
||||
if (!create_tma_desc_2d_bf16(&tma_k, d_k, SK, HD, BLOCK_MN, MMA_K_BF16)) {
|
||||
if (!create_tma_desc_2d_bf16(&tma_k, d_k_sub, SK, (uint64_t)MMA_K_BF16, BLOCK_MN, MMA_K_BF16)) {
|
||||
printf("TMA K desc FAILED\n"); return 1;
|
||||
}
|
||||
cudaMalloc(&d_tma_k, sizeof(CUtensorMap));
|
||||
|
||||
Reference in New Issue
Block a user