diff --git a/tests/unit/test_d5b_perrow_lse.py b/tests/unit/test_d5b_perrow_lse.py index 156cbcf8..886943dc 100644 --- a/tests/unit/test_d5b_perrow_lse.py +++ b/tests/unit/test_d5b_perrow_lse.py @@ -167,8 +167,8 @@ def test_lse_kv_merge(): lses = [] for seg in range(s_k // seg_size): - k_seg = k[seg * seg_size:(seg + 1) * seg_size] - v_seg = v[seg * seg_size:(seg + 1) * seg_size] + k_seg = k[seg * seg_size:(seg + 1) * seg_size].contiguous() + v_seg = v[seg * seg_size:(seg + 1) * seg_size].contiguous() k_seg_3d = k_seg.unsqueeze(-1) o_seg, lse_seg = _run_fmha_with_lse(q, k_seg_3d, v_seg, m, seg_size, hd) @@ -208,8 +208,8 @@ def test_lse_kv_merge_4tiles(): lses = [] for seg in range(s_k // seg_size): - k_seg = k[seg * seg_size:(seg + 1) * seg_size] - v_seg = v[seg * seg_size:(seg + 1) * seg_size] + k_seg = k[seg * seg_size:(seg + 1) * seg_size].contiguous() + v_seg = v[seg * seg_size:(seg + 1) * seg_size].contiguous() k_seg_3d = k_seg.unsqueeze(-1) o_seg, lse_seg = _run_fmha_with_lse(q, k_seg_3d, v_seg, m, seg_size, hd)