From fc0f4bcf23d2e72c26fa7155e94d27e61ca1a29d Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 26 May 2026 15:09:49 +0000 Subject: [PATCH] diag: test D5c with single KV tile (s_k=128) to isolate O rescale issue --- tests/unit/test_d5c_fused.py | 41 +++++++++++++++++++++++++++++------- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/tests/unit/test_d5c_fused.py b/tests/unit/test_d5c_fused.py index b72ac839..d99994b2 100644 --- a/tests/unit/test_d5c_fused.py +++ b/tests/unit/test_d5c_fused.py @@ -298,22 +298,18 @@ def test_d5c_with_causal(): if __name__ == '__main__': - # First: baseline test without sink bias (just multi-tile D3 masking) - print('=== Baseline: multi-tile attention with D3 mask (no D5c) ===\n') - _hd = 64; _m = 128; _s_k = 256; _swa_len = 192 # 192 valid, 64 masked + # Test 0: baseline with s_k=128 (single KV tile, no O rescale) + print('=== Baseline: single-tile D3 mask (s_k=128, no D5c) ===\n') + _hd = 64; _m = 128; _s_k = 128; _swa_len = 64 _scale = 1.0 / math.sqrt(_hd) torch.manual_seed(42) _q = torch.randn(_m, _hd, 1, dtype=torch.bfloat16, device='cuda') _k = torch.randn(_s_k, _hd, 1, dtype=torch.bfloat16, device='cuda') _v = torch.randn(_s_k, _hd, dtype=torch.bfloat16, device='cuda') - - # Reference _qf = _q[:, :, 0].float(); _kf = _k[:, :, 0].float(); _vf = _v.float() _scores = _qf @ _kf.T * _scale _scores[:, _swa_len:] = float('-inf') _ref = (torch.softmax(_scores, dim=-1) @ _vf).to(torch.bfloat16) - - # Kernel _stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) _kernel = FmhaKernel(head_dim=_hd, s_k=_s_k, normalize=False, apply_swa_mask=True, is_causal=False, n_comp=0) _c_out = torch.zeros(_m, _hd, 1, dtype=torch.bfloat16, device='cuda') @@ -327,7 +323,36 @@ if __name__ == '__main__': torch.cuda.synchronize() _o_k = _c_out[:, :, 0].float() / _rs[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30) _cos = torch.nn.functional.cosine_similarity(_o_k.flatten().unsqueeze(0), _ref.flatten().unsqueeze(0).float()).item() - print(f'Baseline (s_k=256, D3 mask, n_comp=0): cos {_cos:.6f} {"PASS" if _cos > 0.99 else "FAIL"}\n') + print(f'Baseline (s_k=128, D3 mask): cos {_cos:.6f} {"PASS" if _cos > 0.99 else "FAIL"}\n') + + # Test 1: D5c with single KV tile (n_comp=64, n_swa=64, s_k=128) + print('=== D5c: single-tile combined KV + sink bias ===\n') + _n_comp = 64; _n_swa = 64; _n_total = 128; _swa_len2 = 40 + _attn_sink = torch.tensor([0.5], dtype=torch.float32, device='cuda') + _k_comp = torch.randn(_n_comp, _hd, 1, dtype=torch.bfloat16, device='cuda') + _v_comp = torch.randn(_n_comp, _hd, dtype=torch.bfloat16, device='cuda') + _k_swa = torch.randn(_n_swa, _hd, 1, dtype=torch.bfloat16, device='cuda') + _v_swa = torch.randn(_n_swa, _hd, dtype=torch.bfloat16, device='cuda') + _k_comb = torch.cat([_k_comp, _k_swa], dim=0) + _v_comb = torch.cat([_v_comp, _v_swa], dim=0) + _ref2 = reference_combined_attention(_qf, _k_comp[:,:,0], _v_comp, _k_swa[:,:,0], _v_swa, 0.5, _scale, _swa_len2) + _kernel2 = FmhaKernel(head_dim=_hd, s_k=_n_total, normalize=False, apply_swa_mask=True, is_causal=False, n_comp=_n_comp) + _c2 = torch.zeros(_m, _hd, 1, dtype=torch.bfloat16, device='cuda') + _lse2 = torch.zeros(_m, 1, 1, dtype=torch.float32, device='cuda') + _rs2 = torch.zeros(_m, 1, 1, dtype=torch.float32, device='cuda') + _mK2 = _tc(_k_comb); _mV2 = _tc(_v_comb.unsqueeze(-1).contiguous()) + _mC2 = _tc(_c2); _mLSE2 = _tc(_lse2); _mRS2 = _tc(_rs2) + _mSB2 = _tc(_attn_sink) + _comp2 = cute.compile(_kernel2, _mQ, _mK2, _mV2, _mC2, _stream, _mLSE2, swa_len=_swa_len2, sink_bias=_mSB2, row_sums=_mRS2) + _comp2(_mQ, _mK2, _mV2, _mC2, _stream, _mLSE2, swa_len=_swa_len2, sink_bias=_mSB2, row_sums=_mRS2) + torch.cuda.synchronize() + _ok2 = _c2[:, :, 0].float() / _rs2[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30) + _cos2 = torch.nn.functional.cosine_similarity(_ok2.flatten().unsqueeze(0), _ref2.flatten().unsqueeze(0).float()).item() + print(f'D5c single-tile: cos {_cos2:.6f} {"PASS" if _cos2 > 0.99 else "FAIL"}') + if _cos2 < 0.99: + print(f' kernel[0,:4]={_ok2[0,:4].tolist()}') + print(f' ref[0,:4]={_ref2[0,:4].tolist()}') + print(f' row_sum range: {_rs2[:,0,0].min().item():.4f} to {_rs2[:,0,0].max().item():.4f}') test_d5c_combined() test_d5c_with_causal()