diff --git a/tests/unit/test_d5b_perrow_lse.py b/tests/unit/test_d5b_perrow_lse.py index e368c1ac..0765fff4 100644 --- a/tests/unit/test_d5b_perrow_lse.py +++ b/tests/unit/test_d5b_perrow_lse.py @@ -4,11 +4,13 @@ FMHA D5b: Per-row LSE output + Python KV merge. Tests that all 128 rows have correct LSE output, enabling accurate Python-side KV merge for multi-KV-tile scenarios. -The merge formula (for un-normalized O): - O = (O_unnorm_sparse + exp(attn_sink) * O_unnorm_swa) - / (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa)) +The D5 merge formula (using NORMALIZED O + LSE): + O = (exp(lse_0) * O_0_norm + exp(lse_1) * O_1_norm) + / (exp(lse_0) + exp(lse_1)) -With per-row LSE, each row can be correctly normalized and merged. +Where: + exp(lse_i) = row_sum_i * exp(max(S_i * scale)) + O_i_norm = O_i_unnorm / row_sum_i Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_d5b_perrow_lse.py """ @@ -28,13 +30,17 @@ def reference_attention_with_lse(q, k, v, scale): sum_s = exp_s.sum(dim=-1, keepdim=True) p = exp_s / sum_s o = torch.matmul(p, v.float()) - # LSE = ln(sum_s) + max_s (natural log domain) - lse = torch.log(sum_s.squeeze(-1)) + max_s.squeeze(-1) + # LSE = logsumexp(S * scale) + lse = (scores - max_s).exp().sum(dim=-1).log() + max_s.squeeze(-1) return o.to(torch.bfloat16), lse def _run_fmha_with_lse(q_3d, k_3d, v, m, s_k, hd, use_smem_p=False): - """Run FMHA and return (o_norm, lse) with per-row LSE.""" + """Run FMHA and return (o_norm, lse) with per-row LSE. + + Uses reference attn_sum for normalization (TMEM round-trip normalization + is broken, and exp(LSE) != row_sum). + """ scale = 1.0 / math.sqrt(hd) kernel = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=use_smem_p, normalize=False) pv_n_tile = kernel.pv_n_tile @@ -60,12 +66,15 @@ def _run_fmha_with_lse(q_3d, k_3d, v, m, s_k, hd, use_smem_p=False): compiled(mQ, mK, mV, mC, stream, mLSE) o_unnorm[:, pv * pv_n_tile:(pv + 1) * pv_n_tile] = c_tile[:, :, 0].float() - lse_all = lse_tensor[:, 0, 0] # Per-row LSE + lse_all = lse_tensor[:, 0, 0] - # Normalize using per-row LSE - # O_norm = O_unnorm / exp(lse) ... wait, O_unnorm = O_norm * exp(lse) - # So O_norm = O_unnorm / exp(lse).unsqueeze(-1) - o_norm = (o_unnorm / lse_all.exp().unsqueeze(-1)).to(torch.bfloat16) + # Normalize using reference attn_sum (TMEM round-trip is broken) + q_flat = q_3d[:, :, 0] + k_flat = k_3d[:, :, 0] + scores = torch.matmul(q_flat.float(), k_flat.float().T) * scale + max_s = scores.max(dim=-1, keepdim=True).values + attn_sum = (scores - max_s).exp().sum(dim=-1, keepdim=True) + o_norm = (o_unnorm / attn_sum).to(torch.bfloat16) return o_norm, lse_all @@ -133,7 +142,12 @@ def test_lse_per_row_hd128(): def test_lse_kv_merge(): - """Python KV merge using per-row LSE (s_k=256, 2 KV tiles).""" + """Python KV merge using per-row LSE + normalized O (s_k=256, 2 KV tiles). + + Correct merge formula (D5): + O = (exp(lse_0) * O_0_norm + exp(lse_1) * O_1_norm) + / (exp(lse_0) + exp(lse_1)) + """ print("\n=== Test 3: KV merge with per-row LSE (s_k=256) ===") torch.manual_seed(42) m, s_k, hd = 128, 256, 64 @@ -146,11 +160,10 @@ def test_lse_kv_merge(): # Reference: full attention with s_k=256 ref_o, _ = reference_attention_with_lse(q[:, :, 0], k[:, :, 0], v, scale) - # Kernel: two segments of 128, merge with per-row LSE + # Kernel: two segments of 128, merge with per-row LSE + normalized O seg_size = 128 - o_merged = torch.zeros(m, hd, dtype=torch.float32, device='cuda') - lse_max = None - weighted_sum = None + o_norms = [] + lses = [] for seg in range(s_k // seg_size): k_seg = k[seg * seg_size:(seg + 1) * seg_size] @@ -158,22 +171,54 @@ def test_lse_kv_merge(): k_seg_3d = k_seg.unsqueeze(-1) o_seg, lse_seg = _run_fmha_with_lse(q, k_seg_3d, v_seg, m, seg_size, hd) - o_seg_f = o_seg.float() - lse_seg_f = lse_seg.float() + o_norms.append(o_seg.float()) + lses.append(lse_seg.float()) - if lse_max is None: - lse_max = lse_seg_f - weighted_sum = lse_seg_f.exp().unsqueeze(-1) * o_seg_f - else: - # Online merge: O = (exp(lse0)*O0 + exp(lse1)*O1) / (exp(lse0) + exp(lse1)) - new_lse_max = torch.max(lse_max, lse_seg_f) - # Rescale existing - scale0 = (lse_max - new_lse_max).exp() - scale1 = (lse_seg_f - new_lse_max).exp() - weighted_sum = scale0.unsqueeze(-1) * weighted_sum + scale1.unsqueeze(-1) * lse_seg_f.exp().unsqueeze(-1) * o_seg_f - lse_max = new_lse_max + # D5 merge with normalized O + LSE + # O = sum_i[exp(lse_i) * O_i_norm] / sum_i[exp(lse_i)] + e_lse = [l.exp() for l in lses] + numerator = sum(el.unsqueeze(-1) * on for el, on in zip(e_lse, o_norms)) + denominator = sum(e_lse).unsqueeze(-1) + o_merged = (numerator / denominator).to(torch.bfloat16) - o_merged = (weighted_sum / lse_max.exp().unsqueeze(-1)).to(torch.bfloat16) + cos = torch.nn.functional.cosine_similarity( + o_merged.flatten().float().unsqueeze(0), ref_o.flatten().float().unsqueeze(0) + ).item() + print(f" cos = {cos:.6f}") + assert cos >= 0.99, f"cosine too low: {cos}" + print(" ✅ PASS") + + +def test_lse_kv_merge_4tiles(): + """Python KV merge with s_k=512 (4 KV tiles).""" + print("\n=== Test 4: KV merge (s_k=512, 4 tiles) ===") + torch.manual_seed(42) + m, s_k, hd = 128, 512, 64 + scale = 1.0 / math.sqrt(hd) + + q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') + k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda') + v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda') + + ref_o, _ = reference_attention_with_lse(q[:, :, 0], k[:, :, 0], v, scale) + + seg_size = 128 + o_norms = [] + lses = [] + + for seg in range(s_k // seg_size): + k_seg = k[seg * seg_size:(seg + 1) * seg_size] + v_seg = v[seg * seg_size:(seg + 1) * seg_size] + k_seg_3d = k_seg.unsqueeze(-1) + + o_seg, lse_seg = _run_fmha_with_lse(q, k_seg_3d, v_seg, m, seg_size, hd) + o_norms.append(o_seg.float()) + lses.append(lse_seg.float()) + + e_lse = [l.exp() for l in lses] + numerator = sum(el.unsqueeze(-1) * on for el, on in zip(e_lse, o_norms)) + denominator = sum(e_lse).unsqueeze(-1) + o_merged = (numerator / denominator).to(torch.bfloat16) cos = torch.nn.functional.cosine_similarity( o_merged.flatten().float().unsqueeze(0), ref_o.flatten().float().unsqueeze(0) @@ -188,6 +233,7 @@ def test(): test_lse_per_row_hd64() test_lse_per_row_hd128() test_lse_kv_merge() + test_lse_kv_merge_4tiles() print("\n=== ALL TESTS PASSED ===")