fix: block_idx() returns a tuple; unpack it and use element [1] (y) as the head CTA index

This commit is contained in:
2026-05-24 23:29:59 +00:00
parent 4c79e5533e
commit 2b76b691cb

View File

@@ -177,7 +177,7 @@ class FmhaKernel:
sP = smem.allocate_tensor(element_type=self.q_dtype,layout=_p_layout,byte_alignment=128,swizzle=_p_swizzle)
# D2: Multi-CTA grid. Use block_idx_y to select Q and O for this CTA's head.
head_cta_idx = cute.arch.block_idx(dim=1) # block_idx_y
_bidx, head_cta_idx, _bidz = cute.arch.block_idx() # grid=(1, num_ctas, 1)
# Q: if num_ctas > 1, mQ has a head dimension. local_tile indexes into it.
# K/V: shared (MQA), always coordinate 0.