SMEM-P: debug with linear index pattern m*128+n

This commit is contained in:
2026-05-23 19:52:46 +00:00
parent 81630037bd
commit 3d044b4747

View File

@@ -359,10 +359,11 @@ class FmhaKernel:
n2 = n // 64
pv_coord = ((m, n0), 0, (n1, n2), 0)
# DEBUG: Write simple test pattern (k+j)*0.01
# This helps verify coordinate mapping
# k and j are loop indices (0-31, 0-3)
pattern_val = Float32(k + j) * Float32(0.01)
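# DEBUG: Write linear index as value: m*128 + n
# The index itself is distinct per (m, n) as long as 0 <= n < 128.
# NOTE(review): the value is cast to self.q_dtype below; if that is bfloat16
# (8 significant bits), only integers up to 256 are exact, so larger indices
# may round to the same stored value -- confirm before relying on uniqueness.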
linear_idx = m * 128 + n
# Convert to Float32 (values 0-16383)
pattern_val = Float32(linear_idx)
p_val_bf16 = pattern_val.to(self.q_dtype)
# Original: p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype)
sP[pv_coord] = p_val_bf16