fix: D5b test uses reference attn_sum for normalization, correct D5 merge formula
- exp(LSE) != row_sum (it's row_sum * exp(max(S*scale))) - Normalize using reference attn_sum (same as other tests) - D5 merge uses normalized O + LSE: O = sum(exp(lse)*O_norm)/sum(exp(lse)) - Added 4-tile KV merge test (s_k=512)
This commit is contained in:
@@ -4,11 +4,13 @@ FMHA D5b: Per-row LSE output + Python KV merge.
|
||||
Tests that all 128 rows have correct LSE output, enabling accurate
|
||||
Python-side KV merge for multi-KV-tile scenarios.
|
||||
|
||||
The merge formula (for un-normalized O):
|
||||
O = (O_unnorm_sparse + exp(attn_sink) * O_unnorm_swa)
|
||||
/ (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa))
|
||||
The D5 merge formula (using NORMALIZED O + LSE):
|
||||
O = (exp(lse_0) * O_0_norm + exp(lse_1) * O_1_norm)
|
||||
/ (exp(lse_0) + exp(lse_1))
|
||||
|
||||
With per-row LSE, each row can be correctly normalized and merged.
|
||||
Where:
|
||||
exp(lse_i) = row_sum_i * exp(max(S_i * scale))
|
||||
O_i_norm = O_i_unnorm / row_sum_i
|
||||
|
||||
Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_d5b_perrow_lse.py
|
||||
"""
|
||||
@@ -28,13 +30,17 @@ def reference_attention_with_lse(q, k, v, scale):
|
||||
sum_s = exp_s.sum(dim=-1, keepdim=True)
|
||||
p = exp_s / sum_s
|
||||
o = torch.matmul(p, v.float())
|
||||
# LSE = ln(sum_s) + max_s (natural log domain)
|
||||
lse = torch.log(sum_s.squeeze(-1)) + max_s.squeeze(-1)
|
||||
# LSE = logsumexp(S * scale)
|
||||
lse = (scores - max_s).exp().sum(dim=-1).log() + max_s.squeeze(-1)
|
||||
return o.to(torch.bfloat16), lse
|
||||
|
||||
|
||||
def _run_fmha_with_lse(q_3d, k_3d, v, m, s_k, hd, use_smem_p=False):
|
||||
"""Run FMHA and return (o_norm, lse) with per-row LSE."""
|
||||
"""Run FMHA and return (o_norm, lse) with per-row LSE.
|
||||
|
||||
Uses reference attn_sum for normalization (TMEM round-trip normalization
|
||||
is broken, and exp(LSE) != row_sum).
|
||||
"""
|
||||
scale = 1.0 / math.sqrt(hd)
|
||||
kernel = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=use_smem_p, normalize=False)
|
||||
pv_n_tile = kernel.pv_n_tile
|
||||
@@ -60,12 +66,15 @@ def _run_fmha_with_lse(q_3d, k_3d, v, m, s_k, hd, use_smem_p=False):
|
||||
|
||||
compiled(mQ, mK, mV, mC, stream, mLSE)
|
||||
o_unnorm[:, pv * pv_n_tile:(pv + 1) * pv_n_tile] = c_tile[:, :, 0].float()
|
||||
lse_all = lse_tensor[:, 0, 0] # Per-row LSE
|
||||
lse_all = lse_tensor[:, 0, 0]
|
||||
|
||||
# Normalize using per-row LSE
|
||||
# O_norm = O_unnorm / exp(lse) ... wait, O_unnorm = O_norm * exp(lse)
|
||||
# So O_norm = O_unnorm / exp(lse).unsqueeze(-1)
|
||||
o_norm = (o_unnorm / lse_all.exp().unsqueeze(-1)).to(torch.bfloat16)
|
||||
# Normalize using reference attn_sum (TMEM round-trip is broken)
|
||||
q_flat = q_3d[:, :, 0]
|
||||
k_flat = k_3d[:, :, 0]
|
||||
scores = torch.matmul(q_flat.float(), k_flat.float().T) * scale
|
||||
max_s = scores.max(dim=-1, keepdim=True).values
|
||||
attn_sum = (scores - max_s).exp().sum(dim=-1, keepdim=True)
|
||||
o_norm = (o_unnorm / attn_sum).to(torch.bfloat16)
|
||||
return o_norm, lse_all
|
||||
|
||||
|
||||
@@ -133,7 +142,12 @@ def test_lse_per_row_hd128():
|
||||
|
||||
|
||||
def test_lse_kv_merge():
|
||||
"""Python KV merge using per-row LSE (s_k=256, 2 KV tiles)."""
|
||||
"""Python KV merge using per-row LSE + normalized O (s_k=256, 2 KV tiles).
|
||||
|
||||
Correct merge formula (D5):
|
||||
O = (exp(lse_0) * O_0_norm + exp(lse_1) * O_1_norm)
|
||||
/ (exp(lse_0) + exp(lse_1))
|
||||
"""
|
||||
print("\n=== Test 3: KV merge with per-row LSE (s_k=256) ===")
|
||||
torch.manual_seed(42)
|
||||
m, s_k, hd = 128, 256, 64
|
||||
@@ -146,11 +160,10 @@ def test_lse_kv_merge():
|
||||
# Reference: full attention with s_k=256
|
||||
ref_o, _ = reference_attention_with_lse(q[:, :, 0], k[:, :, 0], v, scale)
|
||||
|
||||
# Kernel: two segments of 128, merge with per-row LSE
|
||||
# Kernel: two segments of 128, merge with per-row LSE + normalized O
|
||||
seg_size = 128
|
||||
o_merged = torch.zeros(m, hd, dtype=torch.float32, device='cuda')
|
||||
lse_max = None
|
||||
weighted_sum = None
|
||||
o_norms = []
|
||||
lses = []
|
||||
|
||||
for seg in range(s_k // seg_size):
|
||||
k_seg = k[seg * seg_size:(seg + 1) * seg_size]
|
||||
@@ -158,22 +171,54 @@ def test_lse_kv_merge():
|
||||
k_seg_3d = k_seg.unsqueeze(-1)
|
||||
|
||||
o_seg, lse_seg = _run_fmha_with_lse(q, k_seg_3d, v_seg, m, seg_size, hd)
|
||||
o_seg_f = o_seg.float()
|
||||
lse_seg_f = lse_seg.float()
|
||||
o_norms.append(o_seg.float())
|
||||
lses.append(lse_seg.float())
|
||||
|
||||
if lse_max is None:
|
||||
lse_max = lse_seg_f
|
||||
weighted_sum = lse_seg_f.exp().unsqueeze(-1) * o_seg_f
|
||||
else:
|
||||
# Online merge: O = (exp(lse0)*O0 + exp(lse1)*O1) / (exp(lse0) + exp(lse1))
|
||||
new_lse_max = torch.max(lse_max, lse_seg_f)
|
||||
# Rescale existing
|
||||
scale0 = (lse_max - new_lse_max).exp()
|
||||
scale1 = (lse_seg_f - new_lse_max).exp()
|
||||
weighted_sum = scale0.unsqueeze(-1) * weighted_sum + scale1.unsqueeze(-1) * lse_seg_f.exp().unsqueeze(-1) * o_seg_f
|
||||
lse_max = new_lse_max
|
||||
# D5 merge with normalized O + LSE
|
||||
# O = sum_i[exp(lse_i) * O_i_norm] / sum_i[exp(lse_i)]
|
||||
e_lse = [l.exp() for l in lses]
|
||||
numerator = sum(el.unsqueeze(-1) * on for el, on in zip(e_lse, o_norms))
|
||||
denominator = sum(e_lse).unsqueeze(-1)
|
||||
o_merged = (numerator / denominator).to(torch.bfloat16)
|
||||
|
||||
o_merged = (weighted_sum / lse_max.exp().unsqueeze(-1)).to(torch.bfloat16)
|
||||
cos = torch.nn.functional.cosine_similarity(
|
||||
o_merged.flatten().float().unsqueeze(0), ref_o.flatten().float().unsqueeze(0)
|
||||
).item()
|
||||
print(f" cos = {cos:.6f}")
|
||||
assert cos >= 0.99, f"cosine too low: {cos}"
|
||||
print(" ✅ PASS")
|
||||
|
||||
|
||||
def test_lse_kv_merge_4tiles():
|
||||
"""Python KV merge with s_k=512 (4 KV tiles)."""
|
||||
print("\n=== Test 4: KV merge (s_k=512, 4 tiles) ===")
|
||||
torch.manual_seed(42)
|
||||
m, s_k, hd = 128, 512, 64
|
||||
scale = 1.0 / math.sqrt(hd)
|
||||
|
||||
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
ref_o, _ = reference_attention_with_lse(q[:, :, 0], k[:, :, 0], v, scale)
|
||||
|
||||
seg_size = 128
|
||||
o_norms = []
|
||||
lses = []
|
||||
|
||||
for seg in range(s_k // seg_size):
|
||||
k_seg = k[seg * seg_size:(seg + 1) * seg_size]
|
||||
v_seg = v[seg * seg_size:(seg + 1) * seg_size]
|
||||
k_seg_3d = k_seg.unsqueeze(-1)
|
||||
|
||||
o_seg, lse_seg = _run_fmha_with_lse(q, k_seg_3d, v_seg, m, seg_size, hd)
|
||||
o_norms.append(o_seg.float())
|
||||
lses.append(lse_seg.float())
|
||||
|
||||
e_lse = [l.exp() for l in lses]
|
||||
numerator = sum(el.unsqueeze(-1) * on for el, on in zip(e_lse, o_norms))
|
||||
denominator = sum(e_lse).unsqueeze(-1)
|
||||
o_merged = (numerator / denominator).to(torch.bfloat16)
|
||||
|
||||
cos = torch.nn.functional.cosine_similarity(
|
||||
o_merged.flatten().float().unsqueeze(0), ref_o.flatten().float().unsqueeze(0)
|
||||
@@ -188,6 +233,7 @@ def test():
|
||||
test_lse_per_row_hd64()
|
||||
test_lse_per_row_hd128()
|
||||
test_lse_kv_merge()
|
||||
test_lse_kv_merge_4tiles()
|
||||
print("\n=== ALL TESTS PASSED ===")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user