SMEM-P: test pattern based on fragment indices (k,j)
This commit is contained in:
@@ -372,13 +372,22 @@ class FmhaKernel:
|
||||
n2 = n_local // 64
|
||||
pv_coord = ((m_local, n0), 0, (n1, n2), 0)
|
||||
|
||||
# Write actual P value (not test pattern)
|
||||
p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype)
|
||||
# DEBUG: Write pattern based on fragment indices (k,j)
|
||||
# If coordinates wrong, this pattern might work better
|
||||
pattern_val = Float32(k) + Float32(j) * Float32(32.0)
|
||||
p_val_bf16 = pattern_val.to(self.q_dtype)
|
||||
# Original: p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype)
|
||||
sP[pv_coord] = p_val_bf16 # Tensor indexing
|
||||
|
||||
# DEBUG: Print first few coordinates to verify mapping
|
||||
if self.use_smem_p and k < 2 and j < 2:
|
||||
print(f"[SMEM-P DEBUG] k={k}, j={j}, qk_coord=({m},{n}), pv_coord={pv_coord}")
|
||||
|
||||
# DEBUG: Also write pattern based on fragment indices (k,j)
|
||||
# If coordinates wrong, this pattern might work better
|
||||
# pattern_val = Float32(k) + Float32(j) * Float32(32.0)
|
||||
# pattern_bf16 = pattern_val.to(self.q_dtype)
|
||||
# sP[pv_coord] = pattern_bf16
|
||||
|
||||
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
|
||||
Reference in New Issue
Block a user