cute.size(gK, mode=[3]) returns 1 for every value of n because mode 3 of gK is the batch mode, not the KV-tile mode. The KV-tile count should come from self.n_kv_tiles = s_k // 128, which is a plain Python int. This is why the softmax loop only processed kt=0 for every n.
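A minimal pure-Python model of the bug, which does not call CuTe at all. It assumes a flat shape of (128, head_dim, s_k // 128, batch) with the KV-tile count in mode 2 and batch = 1. That shape is an illustrative assumption, not the real layout of gK. mode_size is a hypothetical stand-in for cute.size(x, mode=[...]) on a flat shape. The model shows that reading mode 3 yields the batch extent regardless of s_k, so the kt loop visits only kt=0, while s_k // 128 scales with the sequence length.

```python
# Hypothetical model of the mode-indexing bug; the shape layout is an assumption.
from math import prod

KV_TILE = 128  # tile size along s_k, as in s_k // 128 from the note


def mode_size(shape, modes):
    """Product of the extents of the selected modes.

    Hypothetical helper, loosely modeled on cute.size(x, mode=[...])
    for a flat (non-nested) shape tuple.
    """
    return prod(shape[m] for m in modes)


def make_gk_shape(s_k, head_dim=64, batch=1):
    # Assumed layout: (tile rows, head_dim, n_kv_tiles, batch).
    return (KV_TILE, head_dim, s_k // KV_TILE, batch)


def n_kv_tiles_buggy(shape):
    # Reads mode 3, which in the assumed layout is the batch mode.
    return mode_size(shape, [3])


def n_kv_tiles_fixed(s_k):
    # Plain Python int derived from the sequence length.
    return s_k // KV_TILE


def softmax_tiles_visited(n_kv_tiles):
    # Stand-in for the softmax loop over kt; returns the kt indices it touches.
    return list(range(n_kv_tiles))


if __name__ == "__main__":
    for s_k in (128, 256, 512):
        shape = make_gk_shape(s_k)
        print(s_k,
              softmax_tiles_visited(n_kv_tiles_buggy(shape)),
              softmax_tiles_visited(n_kv_tiles_fixed(s_k)))
```

With batch = 1, the buggy path always yields [0], which matches the symptom of only kt=0 being processed, while the fixed path yields one index per 128-wide KV tile.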