fix: make K/V segments contiguous before passing to kernel (TMA needs contiguous tensors)

This commit is contained in:
2026-05-26 11:00:36 +00:00
parent 5407dc768a
commit 2252d7c865

View File

@@ -167,8 +167,8 @@ def test_lse_kv_merge():
lses = []
for seg in range(s_k // seg_size):
-k_seg = k[seg * seg_size:(seg + 1) * seg_size]
-v_seg = v[seg * seg_size:(seg + 1) * seg_size]
+k_seg = k[seg * seg_size:(seg + 1) * seg_size].contiguous()
+v_seg = v[seg * seg_size:(seg + 1) * seg_size].contiguous()
k_seg_3d = k_seg.unsqueeze(-1)
o_seg, lse_seg = _run_fmha_with_lse(q, k_seg_3d, v_seg, m, seg_size, hd)
@@ -208,8 +208,8 @@ def test_lse_kv_merge_4tiles():
lses = []
for seg in range(s_k // seg_size):
-k_seg = k[seg * seg_size:(seg + 1) * seg_size]
-v_seg = v[seg * seg_size:(seg + 1) * seg_size]
+k_seg = k[seg * seg_size:(seg + 1) * seg_size].contiguous()
+v_seg = v[seg * seg_size:(seg + 1) * seg_size].contiguous()
k_seg_3d = k_seg.unsqueeze(-1)
o_seg, lse_seg = _run_fmha_with_lse(q, k_seg_3d, v_seg, m, seg_size, hd)