From a63f452c8629db11d3f1ca04e83e2ff2ac5aa6ba Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 00:30:49 +0000 Subject: [PATCH] Fix softmax loop: use self.n_kv_tiles not cute.size(gK, mode=[3]) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit cute.size(gK, mode=[3]) returns 1 for ALL n values — mode 3 is batch, not KV tiles. self.n_kv_tiles = s_k // 128 is the correct Python int. This is why softmax only processed kt=0 for all n. --- tests/unit/test_fmha_v3_stage_c.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index ff347d50..c73ffcdc 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -321,10 +321,10 @@ class FmhaV3StageCMulti: # For now: kernel is correct when row_max growth across tiles is # mild (typical for short n with random data); for very long n # the missing rescale shows as accuracy drift. - for kt in range(n_kv_tiles): + for kt in range(self.n_kv_tiles): si_handle = s_cons.wait_and_advance() if kt == 0: - cute.printf("SOFTMAX n_kv_tiles=%d\n", Int32(n_kv_tiles)) + cute.printf("SOFTMAX self.n_kv_tiles=%d\n", Int32(self.n_kv_tiles)) # Load S[kt] tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)