fix: make K/V segments contiguous before passing to kernel (TMA needs contiguous tensors)
This commit is contained in:
@@ -167,8 +167,8 @@ def test_lse_kv_merge():
|
||||
lses = []
|
||||
|
||||
for seg in range(s_k // seg_size):
|
||||
k_seg = k[seg * seg_size:(seg + 1) * seg_size]
|
||||
v_seg = v[seg * seg_size:(seg + 1) * seg_size]
|
||||
k_seg = k[seg * seg_size:(seg + 1) * seg_size].contiguous()
|
||||
v_seg = v[seg * seg_size:(seg + 1) * seg_size].contiguous()
|
||||
k_seg_3d = k_seg.unsqueeze(-1)
|
||||
|
||||
o_seg, lse_seg = _run_fmha_with_lse(q, k_seg_3d, v_seg, m, seg_size, hd)
|
||||
@@ -208,8 +208,8 @@ def test_lse_kv_merge_4tiles():
|
||||
lses = []
|
||||
|
||||
for seg in range(s_k // seg_size):
|
||||
k_seg = k[seg * seg_size:(seg + 1) * seg_size]
|
||||
v_seg = v[seg * seg_size:(seg + 1) * seg_size]
|
||||
k_seg = k[seg * seg_size:(seg + 1) * seg_size].contiguous()
|
||||
v_seg = v[seg * seg_size:(seg + 1) * seg_size].contiguous()
|
||||
k_seg_3d = k_seg.unsqueeze(-1)
|
||||
|
||||
o_seg, lse_seg = _run_fmha_with_lse(q, k_seg_3d, v_seg, m, seg_size, hd)
|
||||
|
||||
Reference in New Issue
Block a user