D1: corrected KV merge test with proper normalized output formula

This commit is contained in:
2026-05-24 22:24:27 +00:00
parent c47f648617
commit 49e66fb6e4

View File

@@ -0,0 +1,164 @@
"""
D1: Multi-KV-tile merge using per-row LSE and NORMALIZED outputs.
Correct formula:
O = sum_i [exp(lse_i) * O_i_norm] / sum_i [exp(lse_i)]
Where O_i_norm = O_i_unnorm / row_sum_i (per-segment normalized output)
And exp(lse_i) = row_sum_i * exp(max(S_i * scale))
"""
import torch, math
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def test_multi_kv_merge(hd=64, s_k=256):
m = 128
n_kv_segments = s_k // 128
torch.manual_seed(42)
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda')
# FP32 reference (full attention)
qf = q[:, :, 0].float()
kf = k[:, :, 0].float()
scale = 1.0 / math.sqrt(hd)
attn = qf @ kf.T * scale
attn_max = attn.max(dim=-1, keepdim=True)[0]
attn_exp = torch.exp(attn - attn_max)
attn_sum = attn_exp.sum(dim=-1, keepdim=True)
ref_norm = (attn_exp / attn_sum) @ v.float()
# Run s_k=128 kernel per KV segment
kernel = FmhaKernel(head_dim=hd, s_k=128, use_smem_p=False, normalize=False)
pv_n_tile = kernel.pv_n_tile
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
# Compile once
k_seg0 = k[:128]
v_tile0 = v[:128, 0:pv_n_tile].contiguous()
v_kernel0 = v_tile0.unsqueeze(-1)
c_tile0 = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
lse_tensor = torch.zeros(m, dtype=torch.float32, device='cuda')
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k_seg0).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_seg0))
mV = ct.from_dlpack(v_kernel0).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel0))
mC = ct.from_dlpack(c_tile0).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile0))
mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
print(f' Compiling (hd={hd}, s_k=128, {n_kv_segments} segments)...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE)
# Accumulate across KV segments
o_norm_accum = None # (m, hd) normalized output
w_accum = None # (m,) weight = exp(lse)
for seg in range(n_kv_segments):
k_start = seg * 128
k_end = k_start + 128
k_seg = k[k_start:k_end]
v_seg = v[k_start:k_end]
seg_o_unnorm = torch.zeros(m, hd, dtype=torch.float32, device='cuda')
for nt in range(1): # hd=64 → n_pv_tiles=1
v_start = nt * pv_n_tile
v_end = v_start + pv_n_tile
v_tile = v_seg[:, v_start:v_end].contiguous()
v_kernel = v_tile.unsqueeze(-1)
c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
lse_tensor.zero_()
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k_seg).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_seg))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
compiled(mQ, mK, mV, mC, stream, mLSE)
torch.cuda.synchronize()
seg_o_unnorm[:, v_start:v_end] = c_tile[:, :, 0].float()
seg_lse = lse_tensor.clone() # (m,) per-row LSE
seg_w = torch.exp(seg_lse) # (m,) = row_sum * exp(max(S * scale))
# Normalize this segment's O
# O_norm = O_unnorm / row_sum
# But we don't have row_sum directly. We have lse = ln(row_sum) + M * ln(2)
# So row_sum = exp(lse - M * ln(2)) = exp(lse) / exp(M * ln(2))
# But M * ln(2) is in the scale_log2 domain...
#
# Actually, the un-normalized O is O_unnorm = P @ V where P = exp(S*scale - row_max)
# And row_sum = sum(P).
# So O_norm = O_unnorm / row_sum.
#
# But row_sum is not directly available. We have lse = ln(row_sum) + row_max * ln(2).
# So row_sum = exp(lse - row_max * ln(2)).
#
# But row_max is in scale_log2 domain: row_max = max(S * scale * log2(e))
# So row_max * ln(2) = max(S * scale)
#
# Therefore: row_sum = exp(lse) / exp(max(S * scale)) = exp(lse) / (2^row_max)
#
# Hmm, we don't have max(S * scale) separately.
# But we don't need it! The merge formula is:
# O = sum_i [exp(lse_i) * O_i_norm] / sum_i [exp(lse_i)]
# = sum_i [exp(lse_i) * O_i_unnorm / row_sum_i] / sum_i [exp(lse_i)]
# = sum_i [exp(lse_i) * O_i_unnorm / (exp(lse_i) / exp(M_i))] / sum_i [exp(lse_i)]
# = sum_i [exp(M_i) * O_i_unnorm] / sum_i [exp(lse_i)]
#
# So the numerator uses exp(M_i) * O_i_unnorm, where M_i = max(S_i * scale).
# But M_i = row_max_i * ln(2), and we don't have row_max_i separately.
#
# We can derive row_max_i from lse and row_sum:
# But we don't have row_sum either.
#
# Alternative: compute O_norm from O_unnorm using:
# O_norm_i = O_unnorm_i / row_sum_i
# row_sum_i = sum(P_i) = sum(exp(S_i * scale - M_i))
#
# In the kernel, row_sum is computed per-thread. We need to output it.
#
# For now, let me compute row_sum from the reference for testing:
seg_kf = k_seg[:, :, 0].float()
seg_attn = qf @ seg_kf.T * scale
seg_attn_max = seg_attn.max(dim=-1)[0]
seg_row_sum = torch.exp(seg_attn - seg_attn_max.unsqueeze(-1)).sum(dim=-1) # (m,)
seg_o_norm = seg_o_unnorm / seg_row_sum.unsqueeze(-1)
# Merge: O = sum_i [exp(lse_i) * O_i_norm] / sum_i [exp(lse_i)]
if o_norm_accum is None:
o_norm_accum = seg_w.unsqueeze(-1) * seg_o_norm
w_accum = seg_w
else:
o_norm_accum = o_norm_accum + seg_w.unsqueeze(-1) * seg_o_norm
w_accum = w_accum + seg_w
o_merged = o_norm_accum / w_accum.unsqueeze(-1)
cos = torch.nn.functional.cosine_similarity(
o_merged.flatten().unsqueeze(0), ref_norm.flatten().unsqueeze(0)
).item()
print(f' hd={hd}, s_k={s_k} ({n_kv_segments} segments): cos_norm {cos:.6f} {"PASS" if cos >= 0.99 else "FAIL"}')
return cos
def test():
print("=== D1: Multi-KV Merge (corrected formula) ===\n")
test_multi_kv_merge(64, 256)
test_multi_kv_merge(64, 384)
test_multi_kv_merge(64, 512)
test_multi_kv_merge(64, 1024)
if __name__ == '__main__':
test()