diff --git a/dsv4/kernels/attention/production.py b/dsv4/kernels/attention/production.py index 58a63573..941454f6 100644 --- a/dsv4/kernels/attention/production.py +++ b/dsv4/kernels/attention/production.py @@ -127,6 +127,11 @@ def _attention_single_head( apply_swa_mask = swa_len is not None apply_sink_bias = sink_bias is not None + # Convert to 2D/3D shapes matching test_d1_kv_merge.py + q_3d = q[0].contiguous().unsqueeze(-1) # (T, hd, 1) + k_3d = k[0].contiguous().unsqueeze(-1) # (N, hd, 1) + v_2d = v[0].contiguous() # (N, hd) + # Segment the KV sequence into 128-entry tiles s_k_per_seg = 128 n_segments = (N + s_k_per_seg - 1) // s_k_per_seg @@ -144,17 +149,19 @@ def _attention_single_head( o_accum = torch.zeros(T, hd, dtype=torch.float32, device='cuda') lse_accum = torch.full((T, 1), float('-inf'), dtype=torch.float32, device='cuda') + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + for seg in range(n_segments): k_start = seg * s_k_per_seg k_end = min(k_start + s_k_per_seg, N) - k_seg = k[:, k_start:k_end] - v_seg = v[:, k_start:k_end] + k_seg = k_3d[k_start:k_end] # (s_k, hd, 1) — same as test + v_seg = v_2d[k_start:k_end] # (s_k, hd) — same as test # Pad last segment if shorter than s_k_per_seg if k_end - k_start < s_k_per_seg: pad_len = s_k_per_seg - (k_end - k_start) - k_seg = torch.cat([k_seg, torch.zeros(1, pad_len, hd, dtype=k.dtype, device='cuda')], dim=1) - v_seg = torch.cat([v_seg, torch.zeros(1, pad_len, hd, dtype=v.dtype, device='cuda')], dim=1) + k_seg = torch.cat([k_seg, torch.zeros(pad_len, hd, 1, dtype=k_seg.dtype, device='cuda')], dim=0) + v_seg = torch.cat([v_seg, torch.zeros(pad_len, hd, dtype=v_seg.dtype, device='cuda')], dim=0) seg_o = torch.zeros(T, hd, dtype=torch.float32, device='cuda') seg_lse = torch.zeros(T, 1, dtype=torch.float32, device='cuda') @@ -162,20 +169,13 @@ def _attention_single_head( for nt in range(n_pv_tiles): v_start = nt * pv_n_tile v_end = v_start + pv_n_tile - v_tile = v_seg[0, :, v_start:v_end].contiguous() # (T, pv_n_tile) 2D - v_kernel = v_tile.unsqueeze(-1) # (T, pv_n_tile, 1) + v_tile = v_seg[:, v_start:v_end].contiguous() + v_kernel = v_tile.unsqueeze(-1) c_tile = torch.zeros(T, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') lse_tensor = torch.zeros(T, 1, 1, dtype=torch.float32, device='cuda') - # Match test_d1_kv_merge shapes exactly - q_input = q[0].contiguous() # (T, hd) 2D - k_input = k_seg[0].contiguous() # (s_k, hd) 2D - q_3d = q_input.unsqueeze(-1) # (T, hd, 1) - k_3d = k_input.unsqueeze(-1) # (s_k, hd, 1) - stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - mQ = ct.from_dlpack(q_3d).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q_3d)) - mK = ct.from_dlpack(k_3d).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_3d)) + mK = ct.from_dlpack(k_seg).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_seg)) mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))