test: try LBO with block_mn=32 (1/4 of M=128)

This commit is contained in:
2026-05-28 10:11:38 +00:00
parent d03e353972
commit 3f95f1c5d4

View File

@@ -52,8 +52,11 @@ test_umma_qk_hd16(const bf16_t* q, const bf16_t* k,
// Descriptors
uint32_t sQ_smem = __cvta_generic_to_shared(sQ);
uint32_t sK_smem = __cvta_generic_to_shared(sK);
uint64_t desc_q = make_umma_desc_kmajor_none(sQ_smem, 128);
uint64_t desc_k = make_umma_desc_kmajor_none(sK_smem, 128);
// Experiment: set LBO to 32, expressed in 16-byte units.
// Hypothesis: M=128 is split into 4 sub-tiles of 128 / 4 = 32 rows each,
// so LBO should be 32 rows * 16 B = 512 bytes, i.e. 512 / 16 = 32 in 16-byte units.
uint64_t desc_q = make_umma_desc_kmajor_none(sQ_smem, 32); // LBO = 32 * 16 = 512B
uint64_t desc_k = make_umma_desc_kmajor_none(sK_smem, 32);
uint32_t idesc = make_idesc(128, 128);
// MMA — 4 warp leaders call the instruction (Layout D requires 4 warps)