diff --git a/tests/unit/test_umma_qk.cu b/tests/unit/test_umma_qk.cu index c39d4817..618796fa 100644 --- a/tests/unit/test_umma_qk.cu +++ b/tests/unit/test_umma_qk.cu @@ -52,8 +52,11 @@ test_umma_qk_hd16(const bf16_t* q, const bf16_t* k, // Descriptors uint32_t sQ_smem = __cvta_generic_to_shared(sQ); uint32_t sK_smem = __cvta_generic_to_shared(sK); - uint64_t desc_q = make_umma_desc_kmajor_none(sQ_smem, 128); - uint64_t desc_k = make_umma_desc_kmajor_none(sK_smem, 128); + // Try LBO = 32 (128/4 * 16 / 16 = 32 in 16B units) + // Hypothesis: M=128 has 4 sub-tiles, each with 32 rows + // So LBO should be 32 * 16 = 512 bytes (32 in 16B units) + uint64_t desc_q = make_umma_desc_kmajor_none(sQ_smem, 32); // LBO = 32 * 16 = 512B + uint64_t desc_k = make_umma_desc_kmajor_none(sK_smem, 32); uint32_t idesc = make_idesc(128, 128); // MMA — 4 warp leaders call the instruction (Layout D requires 4 warps)