diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 043bfbf1..5f4dbc47 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -354,10 +354,17 @@ class FmhaKernel: n = qk_coord[1] # Map to PV SMEM coordinate - n0 = n % 16 - n1 = (n // 16) % 4 - n2 = n // 64 - pv_coord = ((m, n0), 0, (n1, n2), 0) + # Try transposed mapping (maybe PV expects P^T?) + m0 = m % 16 + m1 = (m // 16) % 4 + m2 = m // 64 + pv_coord = ((n, m0), 0, (m1, m2), 0) + + # Original mapping (likely wrong): + # n0 = n % 16 + # n1 = (n // 16) % 4 + # n2 = n // 64 + # pv_coord = ((m, n0), 0, (n1, n2), 0) # DEBUG: Write linear index as value: m*128 + n # This uniquely identifies each position