From 2252d7c86539ca7ac7fccc5cedca30e1632068f5 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 26 May 2026 11:00:36 +0000 Subject: [PATCH] fix: make K/V segments contiguous before passing to kernel (TMA needs contiguous tensors) --- tests/unit/test_d5b_perrow_lse.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_d5b_perrow_lse.py b/tests/unit/test_d5b_perrow_lse.py index 156cbcf8..886943dc 100644 --- a/tests/unit/test_d5b_perrow_lse.py +++ b/tests/unit/test_d5b_perrow_lse.py @@ -167,8 +167,8 @@ def test_lse_kv_merge(): lses = [] for seg in range(s_k // seg_size): - k_seg = k[seg * seg_size:(seg + 1) * seg_size] - v_seg = v[seg * seg_size:(seg + 1) * seg_size] + k_seg = k[seg * seg_size:(seg + 1) * seg_size].contiguous() + v_seg = v[seg * seg_size:(seg + 1) * seg_size].contiguous() k_seg_3d = k_seg.unsqueeze(-1) o_seg, lse_seg = _run_fmha_with_lse(q, k_seg_3d, v_seg, m, seg_size, hd) @@ -208,8 +208,8 @@ def test_lse_kv_merge_4tiles(): lses = [] for seg in range(s_k // seg_size): - k_seg = k[seg * seg_size:(seg + 1) * seg_size] - v_seg = v[seg * seg_size:(seg + 1) * seg_size] + k_seg = k[seg * seg_size:(seg + 1) * seg_size].contiguous() + v_seg = v[seg * seg_size:(seg + 1) * seg_size].contiguous() k_seg_3d = k_seg.unsqueeze(-1) o_seg, lse_seg = _run_fmha_with_lse(q, k_seg_3d, v_seg, m, seg_size, hd)