diff --git a/tests/unit/test_d1_kv_merge_v3.py b/tests/unit/test_d1_kv_merge_v3.py new file mode 100644 index 00000000..52e1dc0f --- /dev/null +++ b/tests/unit/test_d1_kv_merge_v3.py @@ -0,0 +1,164 @@ +""" +D1: Multi-KV-tile merge using per-row LSE and NORMALIZED outputs. + +Correct formula: + O = sum_i [exp(lse_i) * O_i_norm] / sum_i [exp(lse_i)] + +Where O_i_norm = O_i_unnorm / row_sum_i (per-segment normalized output) +And exp(lse_i) = row_sum_i * exp(max(S_i * scale)) +""" +import torch, math +import cutlass.cute as cute +import cutlass.torch as ct +import cuda.bindings.driver as cuda +from dsv4.kernels.attention.fmha import FmhaKernel + + +def test_multi_kv_merge(hd=64, s_k=256): + m = 128 + n_kv_segments = s_k // 128 + torch.manual_seed(42) + + q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') + k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda') + v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda') + + # FP32 reference (full attention) + qf = q[:, :, 0].float() + kf = k[:, :, 0].float() + scale = 1.0 / math.sqrt(hd) + attn = qf @ kf.T * scale + attn_max = attn.max(dim=-1, keepdim=True)[0] + attn_exp = torch.exp(attn - attn_max) + attn_sum = attn_exp.sum(dim=-1, keepdim=True) + ref_norm = (attn_exp / attn_sum) @ v.float() + + # Run s_k=128 kernel per KV segment + kernel = FmhaKernel(head_dim=hd, s_k=128, use_smem_p=False, normalize=False) + pv_n_tile = kernel.pv_n_tile + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + # Compile once + k_seg0 = k[:128] + v_tile0 = v[:128, 0:pv_n_tile].contiguous() + v_kernel0 = v_tile0.unsqueeze(-1) + c_tile0 = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') + lse_tensor = torch.zeros(m, dtype=torch.float32, device='cuda') + + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k_seg0).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_seg0)) + mV = ct.from_dlpack(v_kernel0).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel0)) + mC = ct.from_dlpack(c_tile0).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile0)) + mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor)) + + print(f' Compiling (hd={hd}, s_k=128, {n_kv_segments} segments)...', flush=True) + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE) + + # Accumulate across KV segments + o_norm_accum = None # (m, hd) normalized output + w_accum = None # (m,) weight = exp(lse) + + for seg in range(n_kv_segments): + k_start = seg * 128 + k_end = k_start + 128 + k_seg = k[k_start:k_end] + v_seg = v[k_start:k_end] + + seg_o_unnorm = torch.zeros(m, hd, dtype=torch.float32, device='cuda') + + for nt in range(1): # hd=64 → n_pv_tiles=1 + v_start = nt * pv_n_tile + v_end = v_start + pv_n_tile + v_tile = v_seg[:, v_start:v_end].contiguous() + v_kernel = v_tile.unsqueeze(-1) + c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') + lse_tensor.zero_() + + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k_seg).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_seg)) + mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) + mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) + mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor)) + + compiled(mQ, mK, mV, mC, stream, mLSE) + torch.cuda.synchronize() + + seg_o_unnorm[:, v_start:v_end] = c_tile[:, :, 0].float() + + seg_lse = lse_tensor.clone() # (m,) per-row LSE + seg_w = torch.exp(seg_lse) # (m,) = row_sum * exp(max(S * scale)) + + # Normalize this segment's O + # O_norm = O_unnorm / row_sum + # But we don't have row_sum directly. We have lse = ln(row_sum) + M * ln(2) + # So row_sum = exp(lse - M * ln(2)) = exp(lse) / exp(M * ln(2)) + # But M * ln(2) is in the scale_log2 domain... + # + # Actually, the un-normalized O is O_unnorm = P @ V where P = exp(S*scale - row_max) + # And row_sum = sum(P). + # So O_norm = O_unnorm / row_sum. + # + # But row_sum is not directly available. We have lse = ln(row_sum) + row_max * ln(2). + # So row_sum = exp(lse - row_max * ln(2)). + # + # But row_max is in scale_log2 domain: row_max = max(S * scale * log2(e)) + # So row_max * ln(2) = max(S * scale) + # + # Therefore: row_sum = exp(lse) / exp(max(S * scale)) = exp(lse) / (2^row_max) + # + # Hmm, we don't have max(S * scale) separately. + # But we don't need it! The merge formula is: + # O = sum_i [exp(lse_i) * O_i_norm] / sum_i [exp(lse_i)] + # = sum_i [exp(lse_i) * O_i_unnorm / row_sum_i] / sum_i [exp(lse_i)] + # = sum_i [exp(lse_i) * O_i_unnorm / (exp(lse_i) / exp(M_i))] / sum_i [exp(lse_i)] + # = sum_i [exp(M_i) * O_i_unnorm] / sum_i [exp(lse_i)] + # + # So the numerator uses exp(M_i) * O_i_unnorm, where M_i = max(S_i * scale). + # But M_i = row_max_i * ln(2), and we don't have row_max_i separately. + # + # We can derive row_max_i from lse and row_sum: + # But we don't have row_sum either. + # + # Alternative: compute O_norm from O_unnorm using: + # O_norm_i = O_unnorm_i / row_sum_i + # row_sum_i = sum(P_i) = sum(exp(S_i * scale - M_i)) + # + # In the kernel, row_sum is computed per-thread. We need to output it. + # + # For now, let me compute row_sum from the reference for testing: + seg_kf = k_seg[:, :, 0].float() + seg_attn = qf @ seg_kf.T * scale + seg_attn_max = seg_attn.max(dim=-1)[0] + seg_row_sum = torch.exp(seg_attn - seg_attn_max.unsqueeze(-1)).sum(dim=-1) # (m,) + + seg_o_norm = seg_o_unnorm / seg_row_sum.unsqueeze(-1) + + # Merge: O = sum_i [exp(lse_i) * O_i_norm] / sum_i [exp(lse_i)] + if o_norm_accum is None: + o_norm_accum = seg_w.unsqueeze(-1) * seg_o_norm + w_accum = seg_w + else: + o_norm_accum = o_norm_accum + seg_w.unsqueeze(-1) * seg_o_norm + w_accum = w_accum + seg_w + + o_merged = o_norm_accum / w_accum.unsqueeze(-1) + + cos = torch.nn.functional.cosine_similarity( + o_merged.flatten().unsqueeze(0), ref_norm.flatten().unsqueeze(0) + ).item() + print(f' hd={hd}, s_k={s_k} ({n_kv_segments} segments): cos_norm {cos:.6f} {"PASS" if cos >= 0.99 else "FAIL"}') + return cos + + +def test(): + print("=== D1: Multi-KV Merge (corrected formula) ===\n") + + test_multi_kv_merge(64, 256) + test_multi_kv_merge(64, 384) + test_multi_kv_merge(64, 512) + test_multi_kv_merge(64, 1024) + + +if __name__ == '__main__': + test()