SMEM-P: debug with test pattern (k+j)*0.01
@@ -359,8 +359,12 @@ class FmhaKernel:
n2 = n // 64
pv_coord = ((m, n0), 0, (n1, n2), 0)
# Convert Float32 → BF16 and write to SMEM
p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype)
# DEBUG: Write simple test pattern (k+j)*0.01
# This helps verify coordinate mapping
# k and j are loop indices (0-31, 0-3)
pattern_val = Float32(k + j) * Float32(0.01)
p_val_bf16 = pattern_val.to(self.q_dtype)
# Original: p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype)
sP[pv_coord] = p_val_bf16
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
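
Note on the test pattern: the following is a minimal, standalone Python sketch (not the kernel code) of what the `(k+j)*0.01` values look like after rounding to BF16. It assumes the loop ranges given in the comment above (k in 0-31, j in 0-3) and that `self.q_dtype` is BF16. The helpers `to_float32`, `to_bf16`, `pattern`, and `decode` are hypothetical names introduced only for this illustration; BF16 is emulated by rounding the float32 bit pattern to its upper 16 bits (round to nearest even).

```python
import struct

K_RANGE = range(32)  # assumed from the diff comment: k is 0-31
J_RANGE = range(4)   # assumed from the diff comment: j is 0-3


def to_float32(x):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x):
    """Emulate float32 -> BF16 conversion with round-to-nearest-even.

    Only intended for small finite values such as this test pattern.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # rounding bias, ties to even
    bits &= 0xFFFF0000                   # keep the upper 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def pattern(k, j):
    """Mirror Float32(k + j) * Float32(0.01), then convert to BF16."""
    return to_bf16(to_float32(k + j) * to_float32(0.01))


def decode(value):
    """Recover k + j from a value read back from the buffer."""
    return round(value / 0.01)


# Every (k, j) pair and the set of distinct BF16 values they produce.
pairs = [(k, j) for k in K_RANGE for j in J_RANGE]
values = {pattern(k, j) for k, j in pairs}

if __name__ == "__main__":
    print(len(pairs), "pairs ->", len(values), "distinct BF16 values")
```

Under these assumptions, neighbouring pattern values stay distinct after BF16 rounding, so `decode` recovers `k + j` exactly. The pattern identifies only the sum: pairs such as (1, 2) and (2, 1) write the same value, so a value read back from SMEM narrows the source element down to a diagonal of constant `k + j`, not to a single (k, j).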