fix: block_idx() returns tuple, use [1] for y
This commit is contained in:
@@ -177,7 +177,7 @@ class FmhaKernel:
|
||||
sP = smem.allocate_tensor(element_type=self.q_dtype,layout=_p_layout,byte_alignment=128,swizzle=_p_swizzle)
|
||||
|
||||
# D2: Multi-CTA grid. Use block_idx_y to select Q and O for this CTA's head.
|
||||
head_cta_idx = cute.arch.block_idx(dim=1) # block_idx_y
|
||||
_bidx, head_cta_idx, _bidz = cute.arch.block_idx() # grid=(1, num_ctas, 1)
|
||||
|
||||
# Q: if num_ctas > 1, mQ has a head dimension. local_tile indexes into it.
|
||||
# K/V: shared (MQA), always coordinate 0.
|
||||
|
||||
Reference in New Issue
Block a user