diag: test D5c with single KV tile (s_k=128) to isolate O rescale issue

This commit is contained in:
2026-05-26 15:09:49 +00:00
parent e5381b7312
commit fc0f4bcf23

View File

@@ -298,22 +298,18 @@ def test_d5c_with_causal():
# Script entry point: the diagnostics below run only when this file is executed directly.
# NOTE(review): this text is a diff rendering with the +/- markers and the
# indentation stripped, so superseded and replacement lines appear back to
# back — confirm against the real file before relying on statement order.
if __name__ == '__main__':
# First: baseline test without sink bias (just multi-tile D3 masking)
# NOTE(review): this banner and the s_k=256 setup line look like the
# pre-change version; the s_k=128 lines that follow assign the same names
# again, so the later values are the ones the rest of the block uses.
print('=== Baseline: multi-tile attention with D3 mask (no D5c) ===\n')
_hd = 64; _m = 128; _s_k = 256; _swa_len = 192 # 192 valid, 64 masked
# Test 0: baseline with s_k=128 (single KV tile, no O rescale)
print('=== Baseline: single-tile D3 mask (s_k=128, no D5c) ===\n')
# With _s_k=128 and _swa_len=64 the reference below keeps key columns 0..63
# and masks columns 64..127.
_hd = 64; _m = 128; _s_k = 128; _swa_len = 64
# Score scaling factor 1/sqrt(_hd); _hd is also passed to the kernel as head_dim.
_scale = 1.0 / math.sqrt(_hd)
# Fixed seed so the random inputs are reproducible between runs.
torch.manual_seed(42)
# Random bfloat16 inputs on the GPU: q is (_m, _hd, 1), k is (_s_k, _hd, 1),
# v is (_s_k, _hd).
_q = torch.randn(_m, _hd, 1, dtype=torch.bfloat16, device='cuda')
_k = torch.randn(_s_k, _hd, 1, dtype=torch.bfloat16, device='cuda')
_v = torch.randn(_s_k, _hd, dtype=torch.bfloat16, device='cuda')
# Reference
# Drop the trailing size-1 axis of q and k and upcast everything to float32
# for the PyTorch reference computation.
_qf = _q[:, :, 0].float(); _kf = _k[:, :, 0].float(); _vf = _v.float()
_scores = _qf @ _kf.T * _scale
# Key columns at index >= _swa_len get -inf, so softmax gives them zero weight.
_scores[:, _swa_len:] = float('-inf')
# Row-wise softmax over the key axis, weighted sum of V, cast back to bfloat16.
_ref = (torch.softmax(_scores, dim=-1) @ _vf).to(torch.bfloat16)
# Kernel
# Wrap PyTorch's current CUDA stream handle in a cuda.CUstream.
_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
# n_comp=0 here, matching the "no D5c" baseline label above.
# NOTE(review): normalize=False presumably leaves the output unnormalized,
# which would explain the later division by the row sums — confirm against
# FmhaKernel.
_kernel = FmhaKernel(head_dim=_hd, s_k=_s_k, normalize=False, apply_swa_mask=True, is_causal=False, n_comp=0)
# Zero-initialised bfloat16 output buffer of shape (_m, _hd, 1).
_c_out = torch.zeros(_m, _hd, 1, dtype=torch.bfloat16, device='cuda')
@@ -327,7 +323,36 @@ if __name__ == '__main__':
# Second part of the __main__ diagnostics: check the baseline result, then run
# the D5c single-tile test and the remaining D5c tests.
# Wait for the queued GPU work to finish before reading results on the host.
torch.cuda.synchronize()
# Divide the kernel output by a per-row value from _rs, clamped to at least
# 1e-30 so the division cannot hit zero.
# NOTE(review): _rs is created in lines not shown here; presumably it holds
# the row sums written by the kernel — confirm.
_o_k = _c_out[:, :, 0].float() / _rs[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30)
# Cosine similarity between the flattened kernel output and the flattened
# reference, reduced to a single Python float.
_cos = torch.nn.functional.cosine_similarity(_o_k.flatten().unsqueeze(0), _ref.flatten().unsqueeze(0).float()).item()
# NOTE(review): the two prints below differ only in their label (s_k=256 vs
# s_k=128); in this marker-stripped diff they look like the old line and its
# replacement — confirm that only one exists in the real file.
print(f'Baseline (s_k=256, D3 mask, n_comp=0): cos {_cos:.6f} {"PASS" if _cos > 0.99 else "FAIL"}\n')
print(f'Baseline (s_k=128, D3 mask): cos {_cos:.6f} {"PASS" if _cos > 0.99 else "FAIL"}\n')
# Test 1: D5c with single KV tile (n_comp=64, n_swa=64, s_k=128)
print('=== D5c: single-tile combined KV + sink bias ===\n')
# _n_total equals _n_comp + _n_swa and is passed to the kernel as s_k.
_n_comp = 64; _n_swa = 64; _n_total = 128; _swa_len2 = 40
# One-element float32 sink value for the kernel; the reference call below is
# given the literal 0.5, so the two must be kept in sync.
_attn_sink = torch.tensor([0.5], dtype=torch.float32, device='cuda')
# Separate random K/V for the two segments: K carries a trailing size-1 axis,
# V does not.
_k_comp = torch.randn(_n_comp, _hd, 1, dtype=torch.bfloat16, device='cuda')
_v_comp = torch.randn(_n_comp, _hd, dtype=torch.bfloat16, device='cuda')
_k_swa = torch.randn(_n_swa, _hd, 1, dtype=torch.bfloat16, device='cuda')
_v_swa = torch.randn(_n_swa, _hd, dtype=torch.bfloat16, device='cuda')
# Combined K/V: the comp rows come first, followed by the swa rows, along dim 0.
_k_comb = torch.cat([_k_comp, _k_swa], dim=0)
_v_comb = torch.cat([_v_comp, _v_swa], dim=0)
# Reference result from a helper defined elsewhere in the file; it reuses the
# float32 queries (_qf) and _scale from the baseline test.
_ref2 = reference_combined_attention(_qf, _k_comp[:,:,0], _v_comp, _k_swa[:,:,0], _v_swa, 0.5, _scale, _swa_len2)
# Same kernel configuration as the baseline except s_k=_n_total and n_comp=_n_comp.
_kernel2 = FmhaKernel(head_dim=_hd, s_k=_n_total, normalize=False, apply_swa_mask=True, is_causal=False, n_comp=_n_comp)
# Zero-initialised buffers passed to the kernel: output (_m, _hd, 1) in
# bfloat16, and two (_m, 1, 1) float32 buffers passed as the LSE argument and
# as row_sums.
_c2 = torch.zeros(_m, _hd, 1, dtype=torch.bfloat16, device='cuda')
_lse2 = torch.zeros(_m, 1, 1, dtype=torch.float32, device='cuda')
_rs2 = torch.zeros(_m, 1, 1, dtype=torch.float32, device='cuda')
# NOTE(review): _tc and _mQ come from lines not shown here; _tc presumably
# converts a torch tensor into the form the kernel accepts — confirm.
# V is given a trailing size-1 axis and made contiguous before conversion.
_mK2 = _tc(_k_comb); _mV2 = _tc(_v_comb.unsqueeze(-1).contiguous())
_mC2 = _tc(_c2); _mLSE2 = _tc(_lse2); _mRS2 = _tc(_rs2)
_mSB2 = _tc(_attn_sink)
# Compile the kernel for these arguments, then call the compiled object with
# the same argument list.
_comp2 = cute.compile(_kernel2, _mQ, _mK2, _mV2, _mC2, _stream, _mLSE2, swa_len=_swa_len2, sink_bias=_mSB2, row_sums=_mRS2)
_comp2(_mQ, _mK2, _mV2, _mC2, _stream, _mLSE2, swa_len=_swa_len2, sink_bias=_mSB2, row_sums=_mRS2)
torch.cuda.synchronize()
# Same post-processing as the baseline: divide by the clamped row sums, then
# compare against the reference with cosine similarity.
_ok2 = _c2[:, :, 0].float() / _rs2[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30)
_cos2 = torch.nn.functional.cosine_similarity(_ok2.flatten().unsqueeze(0), _ref2.flatten().unsqueeze(0).float()).item()
print(f'D5c single-tile: cos {_cos2:.6f} {"PASS" if _cos2 > 0.99 else "FAIL"}')
# On failure, print the first four values of row 0 from both outputs and the
# min/max of the row sums.
# NOTE(review): PASS requires > 0.99 but this branch requires < 0.99, so a
# score of exactly 0.99 prints FAIL without these diagnostics.
if _cos2 < 0.99:
print(f' kernel[0,:4]={_ok2[0,:4].tolist()}')
print(f' ref[0,:4]={_ref2[0,:4].tolist()}')
print(f' row_sum range: {_rs2[:,0,0].min().item():.4f} to {_rs2[:,0,0].max().item():.4f}')
# Run the other D5c tests, which are defined elsewhere in the file.
test_d5c_combined()
test_d5c_with_causal()