fix: save row_max into old_row_max before recomputing the softmax row max
This commit is contained in:
@@ -248,6 +248,7 @@ class FmhaV3StageC:
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# Compute row_max from S (element-wise, before exp2)
|
||||
old_row_max = row_max
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
|
||||
Reference in New Issue
Block a user