Match tensor slicing exactly to test_d1_kv_merge (2D slices, 3D unsqueeze)

This commit is contained in:
2026-05-27 06:58:28 +00:00
parent 6ee61717c0
commit 8f8d14c300

View File

@@ -127,6 +127,11 @@ def _attention_single_head(
apply_swa_mask = swa_len is not None
apply_sink_bias = sink_bias is not None
# Convert to 2D/3D shapes matching test_d1_kv_merge.py
q_3d = q[0].contiguous().unsqueeze(-1) # (T, hd, 1)
k_3d = k[0].contiguous().unsqueeze(-1) # (N, hd, 1)
v_2d = v[0].contiguous() # (N, hd)
# Segment the KV sequence into 128-entry tiles
s_k_per_seg = 128
n_segments = (N + s_k_per_seg - 1) // s_k_per_seg
@@ -144,17 +149,19 @@ def _attention_single_head(
o_accum = torch.zeros(T, hd, dtype=torch.float32, device='cuda')
lse_accum = torch.full((T, 1), float('-inf'), dtype=torch.float32, device='cuda')
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
for seg in range(n_segments):
k_start = seg * s_k_per_seg
k_end = min(k_start + s_k_per_seg, N)
k_seg = k[:, k_start:k_end]
v_seg = v[:, k_start:k_end]
k_seg = k_3d[k_start:k_end] # (s_k, hd, 1) — same as test
v_seg = v_2d[k_start:k_end] # (s_k, hd) — same as test
# Pad last segment if shorter than s_k_per_seg
if k_end - k_start < s_k_per_seg:
pad_len = s_k_per_seg - (k_end - k_start)
k_seg = torch.cat([k_seg, torch.zeros(1, pad_len, hd, dtype=k.dtype, device='cuda')], dim=1)
v_seg = torch.cat([v_seg, torch.zeros(1, pad_len, hd, dtype=v.dtype, device='cuda')], dim=1)
k_seg = torch.cat([k_seg, torch.zeros(pad_len, hd, 1, dtype=k_seg.dtype, device='cuda')], dim=0)
v_seg = torch.cat([v_seg, torch.zeros(pad_len, hd, dtype=v_seg.dtype, device='cuda')], dim=0)
seg_o = torch.zeros(T, hd, dtype=torch.float32, device='cuda')
seg_lse = torch.zeros(T, 1, dtype=torch.float32, device='cuda')
@@ -162,20 +169,13 @@ def _attention_single_head(
for nt in range(n_pv_tiles):
v_start = nt * pv_n_tile
v_end = v_start + pv_n_tile
v_tile = v_seg[0, :, v_start:v_end].contiguous() # (T, pv_n_tile) 2D
v_kernel = v_tile.unsqueeze(-1) # (T, pv_n_tile, 1)
v_tile = v_seg[:, v_start:v_end].contiguous()
v_kernel = v_tile.unsqueeze(-1)
c_tile = torch.zeros(T, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
lse_tensor = torch.zeros(T, 1, 1, dtype=torch.float32, device='cuda')
# Match test_d1_kv_merge shapes exactly
q_input = q[0].contiguous() # (T, hd) 2D
k_input = k_seg[0].contiguous() # (s_k, hd) 2D
q_3d = q_input.unsqueeze(-1) # (T, hd, 1)
k_3d = k_input.unsqueeze(-1) # (s_k, hd, 1)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
mQ = ct.from_dlpack(q_3d).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q_3d))
mK = ct.from_dlpack(k_3d).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_3d))
mK = ct.from_dlpack(k_seg).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_seg))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))