Match tensor slicing exactly to test_d1_kv_merge (2D slices, 3D unsqueeze)
This commit is contained in:
@@ -127,6 +127,11 @@ def _attention_single_head(
|
||||
apply_swa_mask = swa_len is not None
|
||||
apply_sink_bias = sink_bias is not None
|
||||
|
||||
# Convert to 2D/3D shapes matching test_d1_kv_merge.py
|
||||
q_3d = q[0].contiguous().unsqueeze(-1) # (T, hd, 1)
|
||||
k_3d = k[0].contiguous().unsqueeze(-1) # (N, hd, 1)
|
||||
v_2d = v[0].contiguous() # (N, hd)
|
||||
|
||||
# Segment the KV sequence into 128-entry tiles
|
||||
s_k_per_seg = 128
|
||||
n_segments = (N + s_k_per_seg - 1) // s_k_per_seg
|
||||
@@ -144,17 +149,19 @@ def _attention_single_head(
|
||||
o_accum = torch.zeros(T, hd, dtype=torch.float32, device='cuda')
|
||||
lse_accum = torch.full((T, 1), float('-inf'), dtype=torch.float32, device='cuda')
|
||||
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
for seg in range(n_segments):
|
||||
k_start = seg * s_k_per_seg
|
||||
k_end = min(k_start + s_k_per_seg, N)
|
||||
k_seg = k[:, k_start:k_end]
|
||||
v_seg = v[:, k_start:k_end]
|
||||
k_seg = k_3d[k_start:k_end] # (s_k, hd, 1) — same as test
|
||||
v_seg = v_2d[k_start:k_end] # (s_k, hd) — same as test
|
||||
|
||||
# Pad last segment if shorter than s_k_per_seg
|
||||
if k_end - k_start < s_k_per_seg:
|
||||
pad_len = s_k_per_seg - (k_end - k_start)
|
||||
k_seg = torch.cat([k_seg, torch.zeros(1, pad_len, hd, dtype=k.dtype, device='cuda')], dim=1)
|
||||
v_seg = torch.cat([v_seg, torch.zeros(1, pad_len, hd, dtype=v.dtype, device='cuda')], dim=1)
|
||||
k_seg = torch.cat([k_seg, torch.zeros(pad_len, hd, 1, dtype=k_seg.dtype, device='cuda')], dim=0)
|
||||
v_seg = torch.cat([v_seg, torch.zeros(pad_len, hd, dtype=v_seg.dtype, device='cuda')], dim=0)
|
||||
|
||||
seg_o = torch.zeros(T, hd, dtype=torch.float32, device='cuda')
|
||||
seg_lse = torch.zeros(T, 1, dtype=torch.float32, device='cuda')
|
||||
@@ -162,20 +169,13 @@ def _attention_single_head(
|
||||
for nt in range(n_pv_tiles):
|
||||
v_start = nt * pv_n_tile
|
||||
v_end = v_start + pv_n_tile
|
||||
v_tile = v_seg[0, :, v_start:v_end].contiguous() # (s_k, pv_n_tile) 2D — rows are KV positions of the (padded) segment, not the T query rows
|
||||
v_kernel = v_tile.unsqueeze(-1) # (s_k, pv_n_tile, 1)
|
||||
v_tile = v_seg[:, v_start:v_end].contiguous()
|
||||
v_kernel = v_tile.unsqueeze(-1)
|
||||
c_tile = torch.zeros(T, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
|
||||
lse_tensor = torch.zeros(T, 1, 1, dtype=torch.float32, device='cuda')
|
||||
|
||||
# Match test_d1_kv_merge shapes exactly
|
||||
q_input = q[0].contiguous() # (T, hd) 2D
|
||||
k_input = k_seg[0].contiguous() # (s_k, hd) 2D
|
||||
q_3d = q_input.unsqueeze(-1) # (T, hd, 1)
|
||||
k_3d = k_input.unsqueeze(-1) # (s_k, hd, 1)
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
mQ = ct.from_dlpack(q_3d).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q_3d))
|
||||
mK = ct.from_dlpack(k_3d).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_3d))
|
||||
mK = ct.from_dlpack(k_seg).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_seg))
|
||||
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
|
||||
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
|
||||
mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
|
||||
|
||||
Reference in New Issue
Block a user