diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 96a11b19..8a136b6f 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -349,9 +349,9 @@ class FmhaKernel: if self.use_smem_p: # Get QK coordinate for this position qk_coord = tTMEM_LOADcS_frg[k, j] - mn = qk_coord[0] - m = mn[0] - n = mn[1] + # qk_coord is (m, n) coordinate + m = qk_coord[0] + n = qk_coord[1] # Map to PV SMEM coordinate n0 = n % 16