From e6c9e6c0d034a55c8fa3c49ac373348850d19431 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 16:31:06 +0000 Subject: [PATCH] D1.4: Add external k_sub merge test for hd=512 (avoids slow in-kernel k_sub compilation) --- tests/unit/test_d1_hd512_merge.py | 137 ++++++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 tests/unit/test_d1_hd512_merge.py diff --git a/tests/unit/test_d1_hd512_merge.py b/tests/unit/test_d1_hd512_merge.py new file mode 100644 index 00000000..fcdd061f --- /dev/null +++ b/tests/unit/test_d1_hd512_merge.py @@ -0,0 +1,137 @@ +"""D1.4 hd=512 test using external k_sub merge. + +Instead of the k_sub path in the kernel (which causes 45+ min compilation), +we call the kernel once per k_sub tile with Q and K pre-sliced. +The online softmax merge (same as D5) combines the partial results. + +The kernel always runs at k_tile=256 (same as hd=256, proven to compile fast). +""" +import torch, math, time +import cutlass.cute as cute +import cutlass.torch as ct +import cuda.bindings.driver as cuda +from dsv4.kernels.attention.fmha import FmhaKernel + + +def test(): + torch.manual_seed(42) + hd, n = 512, 128 + m = 128 + k_tile = 256 + n_k_sub = hd // k_tile # 2 + + q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') + k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda') + v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda') + + # FP32 reference (full attention) + qf = q[:, :, 0].float() + kf = k[:, :, 0].float() + scale = 1.0 / math.sqrt(hd) + attn_max = (qf @ kf.T * scale).max(dim=-1, keepdim=True)[0] + attn_exp = torch.exp(qf @ kf.T * scale - attn_max) + attn_sum = attn_exp.sum(dim=-1, keepdim=True) + ref_unnorm = attn_exp @ v.float() + ref_norm = (attn_exp / attn_sum) @ v.float() + ref_lse = (torch.log(attn_sum.squeeze(-1)) + attn_max.squeeze(-1))[0].item() + + # Use the hd=256 kernel (no k_sub path) with k_tile=256 + # Call once per k_sub tile, merge results via online softmax + kernel = FmhaKernel(head_dim=k_tile, s_k=n, normalize=False) + pv_n_tile = kernel.pv_n_tile + n_pv_tiles = kernel.n_pv_tiles + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + print(f'hd={hd}, k_tile={k_tile}, n_k_sub={n_k_sub}, pv_n_tile={pv_n_tile}', flush=True) + print(f'Compiling k_tile={k_tile} kernel...', flush=True) + + # Compile once with the first k_sub tile + q0 = q[:, 0:k_tile, :].contiguous() + k0 = k[:, 0:k_tile, :].contiguous() + v_tile = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1) + c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') + lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') + + mQ0 = ct.from_dlpack(q0).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q0)) + mK0 = ct.from_dlpack(k0).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k0)) + mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile)) + mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) + mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor)) + + t0 = time.time() + compiled = cute.compile(kernel, mQ0, mK0, mV, mC, stream, mLSE) + t1 = time.time() + print(f'Compilation took {t1-t0:.1f}s', flush=True) + + # Run each k_sub tile and accumulate via online softmax merge + # LSE_i = ln(sum(exp(S_i - m_i))) + m_i (in natural log domain) + # Merge: O = (exp(LSE_0 - LSE_max) * O_0 + exp(LSE_1 - LSE_max) * O_1) / + # (exp(LSE_0 - LSE_max) + exp(LSE_1 - LSE_max)) + # where LSE_max = max(LSE_0, LSE_1) + + # Collect (un-norm O, LSE) for each k_sub and each pv_tile + all_o_unnorm = [] # list of (n_k_sub, hd) tensors + all_lse = [] # list of (n_k_sub,) LSE values + + for ks in range(n_k_sub): + ks_start = ks * k_tile + ks_end = ks_start + k_tile + q_ks = q[:, ks_start:ks_end, :].contiguous() + k_ks = k[:, ks_start:ks_end, :].contiguous() + + o_ks = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') + lse_ks = None + + for nt in range(n_pv_tiles): + v_start = nt * pv_n_tile + v_end = v_start + pv_n_tile + v_tile = v[:, v_start:v_end].contiguous().unsqueeze(-1) + c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') + lse_tensor.zero_() + + mQ = ct.from_dlpack(q_ks).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q_ks)) + mK = ct.from_dlpack(k_ks).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_ks)) + mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile)) + mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) + mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor)) + + compiled(mQ, mK, mV, mC, stream, mLSE) + torch.cuda.synchronize() + + o_ks[:, v_start:v_end, :] = c_tile + if nt == 0: + lse_ks = lse_tensor[0, 0, 0].item() + + all_o_unnorm.append(o_ks[:, :, 0].float()) + all_lse.append(lse_ks) + print(f' k_sub={ks}: lse={lse_ks:.4f}', flush=True) + + # Online softmax merge + # O_unnorm_full = sum_ks exp(lse_ks - lse_max) * O_ks + # Normalization: O_norm = O_unnorm_full / sum_ks exp(lse_ks - lse_max) + lse_max = max(all_lse) + o_merged_unnorm = torch.zeros(m, hd, dtype=torch.float32, device='cuda') + denom = 0.0 + for ks in range(n_k_sub): + w = math.exp(all_lse[ks] - lse_max) + o_merged_unnorm += w * all_o_unnorm[ks] + denom += w + + o_merged_norm = o_merged_unnorm / denom + + cos_unnorm = torch.nn.functional.cosine_similarity( + o_merged_unnorm.flatten().unsqueeze(0), ref_unnorm.flatten().unsqueeze(0) + ).item() + cos_norm = torch.nn.functional.cosine_similarity( + o_merged_norm.flatten().unsqueeze(0), ref_norm.flatten().unsqueeze(0) + ).item() + + status = "PASS" if cos_norm >= 0.99 else "FAIL" + print(f'\nhd=512 (external k_sub merge): cos_unnorm {cos_unnorm:.6f} cos_norm {cos_norm:.6f} {status}') + if cos_norm < 0.99: + print(f' o_merged[0,:4]={o_merged_norm[0,:4].tolist()}') + print(f' ref[0,:4]={ref_norm[0,:4].tolist()}') + + +if __name__ == '__main__': + test()