diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 95015d29..61869bf6 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -382,12 +382,19 @@ class FmhaKernel: # DEBUG: Print first few coordinates to verify mapping if self.use_smem_p and k < 2 and j < 2: print(f"[SMEM-P DEBUG] k={k}, j={j}, qk_coord=({m},{n}), pv_coord={pv_coord}") + # Try to compute offset using crd2idx + try: + offset = cute.crd2idx(pv_coord, sP.layout) + print(f"[SMEM-P DEBUG] offset = {offset}") + except: + print(f"[SMEM-P DEBUG] crd2idx not available") # DEBUG: Also write pattern based on fragment indices (k,j) # If coordinates wrong, this pattern might work better - # pattern_val = Float32(k) + Float32(j) * Float32(32.0) - # pattern_bf16 = pattern_val.to(self.q_dtype) - # sP[pv_coord] = pattern_bf16 + pattern_val = Float32(k) + Float32(j) * Float32(32.0) + p_val_bf16 = pattern_val.to(self.q_dtype) + # Original: p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype) + sP[pv_coord] = p_val_bf16 # Tensor indexing row_sum = row_sum + tTMEM_LOADrS_frg[k, j] s_vec = tTMEM_LOADrS_frg[None, j].load()