diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 44f98718..49522eb0 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -346,7 +346,7 @@ class FmhaKernel: # rP_bf16 uses tTMEM_LOADrS.layout (TMEM layout) causing rank mismatch # Create view with QK C-fragment layout (tStS0.layout) rP_qk_layout = tStS0.layout # QK C-fragment layout for this thread - rP_qk = cute.make_tensor(cute.recast_ptr(rP.iterator, dtype=self.q_dtype), rP_qk_layout) + rP_qk = cute.make_tensor(cute.recast_ptr(rP_bf16.iterator, dtype=self.q_dtype), rP_qk_layout) # Partition source with QK layout tSMEM_CPYrP_qk = thr_smem_copy.partition_S(rP_qk)