diff --git a/tests/unit/test_umma_qk.cu b/tests/unit/test_umma_qk.cu index a75de0d5..9e6e59be 100644 --- a/tests/unit/test_umma_qk.cu +++ b/tests/unit/test_umma_qk.cu @@ -66,12 +66,12 @@ test_umma_qk_hd16(const bf16_t* q, const bf16_t* k, // Descriptors uint32_t sQ_smem = __cvta_generic_to_shared(sQ); uint32_t sK_smem = __cvta_generic_to_shared(sK); - uint64_t desc_q = make_umma_desc_kmajor_none(sQ_smem, 128); - uint64_t desc_k = make_umma_desc_kmajor_none(sK_smem, 128); + uint64_t desc_q = make_umma_desc_kmajor_none(sQ_smem, 64); // M=64 + uint64_t desc_k = make_umma_desc_kmajor_none(sK_smem, 64); // M=64 // Test with different N values to understand scaling // N=8 → n_dim=1 - uint32_t idesc = make_idesc(128, 8); // Try N=8 + uint32_t idesc = make_idesc(64, 128); // Try M=64 if (tid == 0) { memcpy(&s_out[128], &desc_q, 8);