diag: test D5c with single KV tile (s_k=128) to isolate O rescale issue
This commit is contained in:
@@ -298,22 +298,18 @@ def test_d5c_with_causal():
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# First: baseline test without sink bias (just multi-tile D3 masking)
|
||||
print('=== Baseline: multi-tile attention with D3 mask (no D5c) ===\n')
|
||||
_hd = 64; _m = 128; _s_k = 256; _swa_len = 192 # 192 valid, 64 masked
|
||||
# Test 0: baseline with s_k=128 (single KV tile, no O rescale)
|
||||
print('=== Baseline: single-tile D3 mask (s_k=128, no D5c) ===\n')
|
||||
_hd = 64; _m = 128; _s_k = 128; _swa_len = 64
|
||||
_scale = 1.0 / math.sqrt(_hd)
|
||||
torch.manual_seed(42)
|
||||
_q = torch.randn(_m, _hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
_k = torch.randn(_s_k, _hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
_v = torch.randn(_s_k, _hd, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
# Reference
|
||||
_qf = _q[:, :, 0].float(); _kf = _k[:, :, 0].float(); _vf = _v.float()
|
||||
_scores = _qf @ _kf.T * _scale
|
||||
_scores[:, _swa_len:] = float('-inf')
|
||||
_ref = (torch.softmax(_scores, dim=-1) @ _vf).to(torch.bfloat16)
|
||||
|
||||
# Kernel
|
||||
_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
_kernel = FmhaKernel(head_dim=_hd, s_k=_s_k, normalize=False, apply_swa_mask=True, is_causal=False, n_comp=0)
|
||||
_c_out = torch.zeros(_m, _hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
@@ -327,7 +323,36 @@ if __name__ == '__main__':
|
||||
torch.cuda.synchronize()
|
||||
_o_k = _c_out[:, :, 0].float() / _rs[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30)
|
||||
_cos = torch.nn.functional.cosine_similarity(_o_k.flatten().unsqueeze(0), _ref.flatten().unsqueeze(0).float()).item()
|
||||
print(f'Baseline (s_k=256, D3 mask, n_comp=0): cos {_cos:.6f} {"PASS" if _cos > 0.99 else "FAIL"}\n')
|
||||
print(f'Baseline (s_k=128, D3 mask): cos {_cos:.6f} {"PASS" if _cos > 0.99 else "FAIL"}\n')
|
||||
|
||||
# Test 1: D5c with single KV tile (n_comp=64, n_swa=64, s_k=128)
|
||||
print('=== D5c: single-tile combined KV + sink bias ===\n')
|
||||
_n_comp = 64; _n_swa = 64; _n_total = 128; _swa_len2 = 40
|
||||
_attn_sink = torch.tensor([0.5], dtype=torch.float32, device='cuda')
|
||||
_k_comp = torch.randn(_n_comp, _hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
_v_comp = torch.randn(_n_comp, _hd, dtype=torch.bfloat16, device='cuda')
|
||||
_k_swa = torch.randn(_n_swa, _hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
_v_swa = torch.randn(_n_swa, _hd, dtype=torch.bfloat16, device='cuda')
|
||||
_k_comb = torch.cat([_k_comp, _k_swa], dim=0)
|
||||
_v_comb = torch.cat([_v_comp, _v_swa], dim=0)
|
||||
_ref2 = reference_combined_attention(_qf, _k_comp[:,:,0], _v_comp, _k_swa[:,:,0], _v_swa, 0.5, _scale, _swa_len2)
|
||||
_kernel2 = FmhaKernel(head_dim=_hd, s_k=_n_total, normalize=False, apply_swa_mask=True, is_causal=False, n_comp=_n_comp)
|
||||
_c2 = torch.zeros(_m, _hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
_lse2 = torch.zeros(_m, 1, 1, dtype=torch.float32, device='cuda')
|
||||
_rs2 = torch.zeros(_m, 1, 1, dtype=torch.float32, device='cuda')
|
||||
_mK2 = _tc(_k_comb); _mV2 = _tc(_v_comb.unsqueeze(-1).contiguous())
|
||||
_mC2 = _tc(_c2); _mLSE2 = _tc(_lse2); _mRS2 = _tc(_rs2)
|
||||
_mSB2 = _tc(_attn_sink)
|
||||
_comp2 = cute.compile(_kernel2, _mQ, _mK2, _mV2, _mC2, _stream, _mLSE2, swa_len=_swa_len2, sink_bias=_mSB2, row_sums=_mRS2)
|
||||
_comp2(_mQ, _mK2, _mV2, _mC2, _stream, _mLSE2, swa_len=_swa_len2, sink_bias=_mSB2, row_sums=_mRS2)
|
||||
torch.cuda.synchronize()
|
||||
_ok2 = _c2[:, :, 0].float() / _rs2[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30)
|
||||
_cos2 = torch.nn.functional.cosine_similarity(_ok2.flatten().unsqueeze(0), _ref2.flatten().unsqueeze(0).float()).item()
|
||||
print(f'D5c single-tile: cos {_cos2:.6f} {"PASS" if _cos2 > 0.99 else "FAIL"}')
|
||||
if _cos2 < 0.99:
|
||||
print(f' kernel[0,:4]={_ok2[0,:4].tolist()}')
|
||||
print(f' ref[0,:4]={_ref2[0,:4].tolist()}')
|
||||
print(f' row_sum range: {_rs2[:,0,0].min().item():.4f} to {_rs2[:,0,0].max().item():.4f}')
|
||||
|
||||
test_d5c_combined()
|
||||
test_d5c_with_causal()
|
||||
|
||||
Reference in New Issue
Block a user