Fix softmax loop: use self.n_kv_tiles not cute.size(gK, mode=[3])
cute.size(gK, mode=[3]) returns 1 for ALL n values — mode 3 is batch, not KV tiles. self.n_kv_tiles = s_k // 128 is the correct Python int. This is why softmax only processed kt=0 for all n.
This commit is contained in:
@@ -321,10 +321,10 @@ class FmhaV3StageCMulti:
|
||||
# For now: kernel is correct when row_max growth across tiles is
|
||||
# mild (typical for short n with random data); for very long n
|
||||
# the missing rescale shows as accuracy drift.
|
||||
for kt in range(n_kv_tiles):
|
||||
for kt in range(self.n_kv_tiles):
|
||||
si_handle = s_cons.wait_and_advance()
|
||||
if kt == 0:
|
||||
cute.printf("SOFTMAX n_kv_tiles=%d\n", Int32(n_kv_tiles))
|
||||
cute.printf("SOFTMAX self.n_kv_tiles=%d\n", Int32(self.n_kv_tiles))
|
||||
|
||||
# Load S[kt]
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
|
||||
Reference in New Issue
Block a user