diff --git a/tests/unit/test_umma_qk.cu b/tests/unit/test_umma_qk.cu index b7db01de..4030749f 100644 --- a/tests/unit/test_umma_qk.cu +++ b/tests/unit/test_umma_qk.cu @@ -41,7 +41,7 @@ test_umma_qk_hd16( // TMEM alloc if (wid == 0) { uint32_t smem_ptr = __cvta_generic_to_shared(sTmemBase); - tmem_alloc(smem_ptr, 32); + tmem_alloc(smem_ptr, 128); // 128 columns for N=128 } __syncthreads(); uint32_t tmem_base = *sTmemBase; @@ -56,7 +56,7 @@ test_umma_qk_hd16( uint32_t sK_smem = __cvta_generic_to_shared(sK); uint64_t desc_q = make_umma_desc_kmajor_none(sQ_smem, 128); uint64_t desc_k = make_umma_desc_kmajor_none(sK_smem, 128); - uint32_t idesc = make_idesc(128, 32); + uint32_t idesc = make_idesc(128, 128); // Try N=128 (full extent) // Verify SMEM Q and K by reading back row 0 if (tid == 0) { @@ -104,7 +104,7 @@ test_umma_qk_hd16( } __syncthreads(); - if (wid == 0) tmem_dealloc(tmem_base, 32); + if (wid == 0) tmem_dealloc(tmem_base, 128); } int main() {