fix: use plain range loop for row_max (fmax not allowed in vectorized)
This commit is contained in:
@@ -253,7 +253,7 @@ class FmhaV3StageC:
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
|
||||
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
|
||||
row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2)
|
||||
|
||||
row_max_safe = row_max
|
||||
|
||||
Reference in New Issue
Block a user