Fix rP scope issue: use rP_bf16.iterator instead of rP.iterator
This commit is contained in:
@@ -346,7 +346,7 @@ class FmhaKernel:
|
||||
# rP_bf16 uses tTMEM_LOADrS.layout (TMEM layout) causing rank mismatch
|
||||
# Create view with QK C-fragment layout (tStS0.layout)
|
||||
rP_qk_layout = tStS0.layout # QK C-fragment layout for this thread
|
||||
- rP_qk = cute.make_tensor(cute.recast_ptr(rP.iterator, dtype=self.q_dtype), rP_qk_layout)
|
||||
+ rP_qk = cute.make_tensor(cute.recast_ptr(rP_bf16.iterator, dtype=self.q_dtype), rP_qk_layout)
|
||||
|
||||
# Partition source with QK layout
|
||||
tSMEM_CPYrP_qk = thr_smem_copy.partition_S(rP_qk)
|
||||
|
||||
Reference in New Issue
Block a user