Fix softmax loop: use self.n_kv_tiles not cute.size(gK, mode=[3])

cute.size(gK, mode=[3]) returns 1 for ALL n values — mode 3 is batch,
not KV tiles. self.n_kv_tiles = s_k // 128 is the correct Python int.
This is why softmax only processed kt=0 for all n.
This commit is contained in:
2026-05-23 00:30:49 +00:00
parent 195b0506af
commit a63f452c86

View File

@@ -321,10 +321,10 @@ class FmhaV3StageCMulti:
# For now: kernel is correct when row_max growth across tiles is
# mild (typical for short n with random data); for very long n
# the missing rescale shows as accuracy drift.
for kt in range(n_kv_tiles):
for kt in range(self.n_kv_tiles):
si_handle = s_cons.wait_and_advance()
if kt == 0:
cute.printf("SOFTMAX n_kv_tiles=%d\n", Int32(n_kv_tiles))
cute.printf("SOFTMAX self.n_kv_tiles=%d\n", Int32(self.n_kv_tiles))
# Load S[kt]
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)