SMEM-P: try transposed mapping (swap m/n)
This commit is contained in:
@@ -354,10 +354,17 @@ class FmhaKernel:
|
||||
n = qk_coord[1]
|
||||
|
||||
# Map to PV SMEM coordinate
|
||||
n0 = n % 16
|
||||
n1 = (n // 16) % 4
|
||||
n2 = n // 64
|
||||
pv_coord = ((m, n0), 0, (n1, n2), 0)
|
||||
# Try transposed mapping (maybe PV expects P^T?)
|
||||
m0 = m % 16
|
||||
m1 = (m // 16) % 4
|
||||
m2 = m // 64
|
||||
pv_coord = ((n, m0), 0, (m1, m2), 0)
|
||||
|
||||
# Original mapping (likely wrong):
|
||||
# n0 = n % 16
|
||||
# n1 = (n // 16) % 4
|
||||
# n2 = n // 64
|
||||
# pv_coord = ((m, n0), 0, (n1, n2), 0)
|
||||
|
||||
# DEBUG: Write linear index as value: m*128 + n
|
||||
# This uniquely identifies each position
|
||||
|
||||
Reference in New Issue
Block a user