diff --git a/tests/unit/test_d5b_perrow_lse.py b/tests/unit/test_d5b_perrow_lse.py index 0765fff4..156cbcf8 100644 --- a/tests/unit/test_d5b_perrow_lse.py +++ b/tests/unit/test_d5b_perrow_lse.py @@ -161,6 +161,7 @@ def test_lse_kv_merge(): ref_o, _ = reference_attention_with_lse(q[:, :, 0], k[:, :, 0], v, scale) # Kernel: two segments of 128, merge with per-row LSE + normalized O + # IMPORTANT: create kernel with s_k=128 (segment size), not s_k=256 seg_size = 128 o_norms = [] lses = []