D1: add KV merge test using log-sum-exp (avoids TMEM round-trip)

This commit is contained in:
2026-05-24 22:17:24 +00:00
parent 0f30319e06
commit 02edff5ac7
3 changed files with 336 additions and 0 deletions

View File

@@ -0,0 +1,119 @@
"""
D1: Test multi-KV-tile attention by running the s_k=128 kernel once per KV
segment and merging the per-segment results in Python using log-sum-exp
(D5 merge formula).
This avoids the broken TMEM round-trip O rescale entirely.
"""
import torch, math
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def test_multi_kv_merge(hd=64, s_k=256):
m = 128
n_kv_segments = s_k // 128
torch.manual_seed(42)
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda')
# FP32 reference (full attention)
qf = q[:, :, 0].float()
kf = k[:, :, 0].float()
scale = 1.0 / math.sqrt(hd)
attn_max = (qf @ kf.T * scale).max(dim=-1, keepdim=True)[0]
attn_exp = torch.exp(qf @ kf.T * scale - attn_max)
attn_sum = attn_exp.sum(dim=-1, keepdim=True)
ref_norm = (attn_exp / attn_sum) @ v.float()
# Run s_k=128 kernel per KV segment and merge using log-sum-exp
kernel = FmhaKernel(head_dim=hd, s_k=128, use_smem_p=False, normalize=False)
pv_n_tile = kernel.pv_n_tile
n_pv_tiles = kernel.n_pv_tiles
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
# Compile once with segment 0's K
k_seg = k[:128]
v_tile = v[:, 0:pv_n_tile].contiguous()
v_kernel = v_tile.unsqueeze(-1)
c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k_seg).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_seg))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
print(f' Compiling (hd={hd}, s_k=128 per segment, {n_kv_segments} segments)...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE)
# Accumulate across KV segments using log-sum-exp merge
# O_merged = sum_i(exp(lse_i) * O_i) / sum_i(exp(lse_i))
o_accum = torch.zeros(m, hd, dtype=torch.float32, device='cuda')
lse_accum = torch.full((m, 1), float('-inf'), dtype=torch.float32, device='cuda')
for seg in range(n_kv_segments):
k_start = seg * 128
k_end = k_start + 128
k_seg = k[k_start:k_end]
v_seg = v[k_start:k_end]
# Per-segment O and LSE
seg_o = torch.zeros(m, hd, dtype=torch.float32, device='cuda')
seg_lse = torch.zeros(m, 1, dtype=torch.float32, device='cuda')
for nt in range(n_pv_tiles):
v_start = nt * pv_n_tile
v_end = v_start + pv_n_tile
v_tile = v_seg[:, v_start:v_end].contiguous()
v_kernel = v_tile.unsqueeze(-1)
c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
lse_tensor.zero_()
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k_seg).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_seg))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
compiled(mQ, mK, mV, mC, stream, mLSE)
torch.cuda.synchronize()
seg_o[:, v_start:v_end] = c_tile[:, :, 0].float()
if nt == 0:
seg_lse[:, 0] = lse_tensor[:, 0, 0].float()
# Merge with accumulator using log-sum-exp
# O_new = (exp(lse_old) * O_old + exp(lse_new) * O_new) / (exp(lse_old) + exp(lse_new))
# lse_new = ln(exp(lse_old) + exp(lse_new))
e_old = torch.exp(lse_accum) # (m, 1)
e_new = torch.exp(seg_lse) # (m, 1)
e_sum = e_old + e_new
o_accum = (e_old * o_accum + e_new * seg_o) / e_sum
lse_accum = torch.log(e_sum)
cos = torch.nn.functional.cosine_similarity(
o_accum.flatten().unsqueeze(0), ref_norm.flatten().unsqueeze(0)
).item()
print(f' hd={hd}, s_k={s_k} ({n_kv_segments} segments): cos_norm {cos:.6f} {"PASS" if cos >= 0.99 else "FAIL"}')
return cos
def test():
    """Run the log-sum-exp KV merge check over a fixed list of (hd, s_k) pairs."""
    print("=== D1: Multi-KV Merge via Log-Sum-Exp (no TMEM round-trip) ===\n")
    # (head_dim, total KV length) configurations, executed in this order.
    configs = (
        (64, 256),
        (64, 384),
        (64, 512),
        (64, 1024),
        (128, 256),
    )
    for head_dim, kv_len in configs:
        test_multi_kv_merge(head_dim, kv_len)
# Run the full configuration sweep when this file is executed as a script.
if __name__ == "__main__":
    test()

View File

@@ -0,0 +1,115 @@
"""
D1: Minimal O rescale test with just s_k=256 at hd=64.
Tests the exact same thing as test_d1_multi_kv but simpler.
"""
import torch, math
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def test():
    """Run the hd=64, s_k=256 kernel and compare its output to FP32 references.

    Builds full-attention references (unnormalized and normalized) plus
    per-KV-tile references for debugging, launches the kernel once per PV
    tile, and prints cosine similarities and a PASS/FAIL verdict based on
    the unnormalized output.
    """
    head_dim, kv_len, rows = 64, 256, 128
    kv_tile_count = kv_len // 128

    torch.manual_seed(42)
    bf16_cuda = dict(dtype=torch.bfloat16, device='cuda')
    q = torch.randn(rows, head_dim, 1, **bf16_cuda)
    k = torch.randn(kv_len, head_dim, 1, **bf16_cuda)
    v = torch.randn(kv_len, head_dim, **bf16_cuda)
    out_buf = torch.zeros(rows, head_dim, 1, **bf16_cuda)

    # FP32 reference over the full KV range.
    q32 = q[:, :, 0].float()
    k32 = k[:, :, 0].float()
    softmax_scale = 1.0 / math.sqrt(head_dim)
    scores = q32 @ k32.T * softmax_scale
    row_max = scores.max(dim=-1, keepdim=True).values
    probs_unnorm = torch.exp(scores - row_max)
    row_sum = probs_unnorm.sum(dim=-1, keepdim=True)
    ref_unnorm = probs_unnorm @ v.float()
    ref_norm = (probs_unnorm / row_sum) @ v.float()

    # Debug references: KV tile 0 on its own ...
    scores_t0 = q32 @ k[:128, :, 0].float().T * softmax_scale
    max_t0 = scores_t0.max(dim=-1, keepdim=True).values
    ref_tile0 = torch.exp(scores_t0 - max_t0) @ v[:128].float()
    # ... then KV tile 1 folded in, rescaling tile 0 to the new running max.
    scores_t1 = q32 @ k[128:, :, 0].float().T * softmax_scale
    max_running = torch.max(max_t0, scores_t1.max(dim=-1, keepdim=True).values)
    acc_scale = torch.exp(max_t0 - max_running)
    ref_rescaled = acc_scale * ref_tile0 + torch.exp(scores_t1 - max_running) @ v[128:].float()

    print(f" Tile-0 only O[0,:4] = {ref_tile0[0,:4].tolist()}")
    print(f" Rescaled O[0,:4] = {ref_rescaled[0,:4].tolist()}")
    print(f" Full ref O[0,:4] = {ref_unnorm[0,:4].tolist()}")
    print(f" acc_scale range = [{acc_scale.min().item():.4f}, {acc_scale.max().item():.4f}]")

    def wrap(t):
        # Wrap a torch tensor for the kernel with a dynamic layout.
        return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))

    def flat_cos(a, b):
        # Cosine similarity of two tensors viewed as single flat vectors.
        return torch.nn.functional.cosine_similarity(
            a.flatten().unsqueeze(0), b.flatten().unsqueeze(0)
        ).item()

    lse_tensor = torch.zeros(rows, 1, 1, dtype=torch.float32, device='cuda')
    kernel = FmhaKernel(head_dim=head_dim, s_k=kv_len, use_smem_p=False, normalize=False)
    pv_n_tile = kernel.pv_n_tile
    n_pv_tiles = kernel.n_pv_tiles
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # Compile against the first PV tile's tensors.
    first_v = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
    first_c = torch.zeros(rows, pv_n_tile, 1, **bf16_cuda)
    compile_args = (wrap(q), wrap(k), wrap(first_v), wrap(first_c), stream, wrap(lse_tensor))
    print(f' Compiling (n_kv_tiles={kv_tile_count})...', flush=True)
    compiled = cute.compile(kernel, *compile_args)

    lse_val = None
    for tile_idx in range(n_pv_tiles):
        col_lo = tile_idx * pv_n_tile
        col_hi = col_lo + pv_n_tile
        v_in = v[:, col_lo:col_hi].contiguous().unsqueeze(-1)
        c_out = torch.zeros(rows, pv_n_tile, 1, **bf16_cuda)
        lse_tensor.zero_()
        launch_args = (wrap(q), wrap(k), wrap(v_in), wrap(c_out), stream, wrap(lse_tensor))
        compiled(*launch_args)
        torch.cuda.synchronize()
        out_buf[:, col_lo:col_hi, :] = c_out
        if tile_idx == 0:
            # Only row 0 of the first tile's LSE is kept, for display.
            lse_val = lse_tensor[0, 0, 0].item()

    out_unnorm = out_buf[:, :, 0].float()
    out_norm = out_unnorm / row_sum
    cos_unnorm = flat_cos(out_unnorm, ref_unnorm)
    cos_norm = flat_cos(out_norm, ref_norm)
    print(f" cos_unnorm={cos_unnorm:.6f} cos_norm={cos_norm:.6f}")
    print(f" out[0,:4]={out_unnorm[0,:4].tolist()}")
    print(f" lse_val={lse_val}")
    print(f" {'PASS' if cos_unnorm >= 0.99 else 'FAIL'}")
# Run the minimal s_k=256 check when this file is executed as a script.
if __name__ == "__main__":
    test()

View File

@@ -0,0 +1,102 @@
"""
D1: Test TMEM round-trip on O in isolation.
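Test 1 runs the kernel with s_k=128 (1 KV tile, no rescale needed) as a baseline.
Test 2 runs the kernel with s_k=256 (2 KV tiles), which exercises the O rescale,
i.e. the load-modify-store round-trip on O in TMEM using the
correction_rescale atoms.
NOTE: the round-trip is not performed manually by this script; it only happens
inside the kernel during Test 2.
If Test 1 passes but Test 2 fails, the round-trip (the atoms) is the prime suspect.
If both pass, the round-trip preserves data and the bug is elsewhere.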
"""
import torch, math
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def _run_first_pv_tile(q, k, v, s_k, label):
    """Compile and run the kernel on the first PV tile and score the result.

    The kernel output covers only the first ``pv_n_tile`` columns of O, so
    the FP32 reference is sliced to the same columns before comparison.
    The tile width is taken from the kernel built for this ``s_k``.

    Args:
        q: query tensor, indexed here as ``q[:, :, 0]``.
        k: key tensor, indexed here as ``k[:, :, 0]``.
        v: value tensor, sliced here as ``v[:, :pv_n_tile]``.
        s_k: KV length passed to ``FmhaKernel``.
        label: line printed before compilation starts.

    Returns:
        Cosine similarity between the kernel output and the unnormalized
        FP32 reference for the first PV tile.
    """
    m, hd = q.shape[0], q.shape[1]

    def to_cute(t):
        # Wrap a torch tensor for the kernel with a dynamic layout.
        return ct.from_dlpack(t).mark_layout_dynamic(
            leading_dim=ct.get_leading_dim(t)
        )

    kernel = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=False, normalize=False)
    pv_n_tile = kernel.pv_n_tile

    # FP32 reference (unnormalized), restricted to the columns the kernel
    # actually produces in this single launch.
    qf = q[:, :, 0].float()
    kf = k[:, :, 0].float()
    scale = 1.0 / math.sqrt(hd)
    scores = qf @ kf.T * scale
    attn_exp = torch.exp(scores - scores.max(dim=-1, keepdim=True)[0])
    ref_unnorm = attn_exp @ v[:, 0:pv_n_tile].float()

    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    v_kernel = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
    c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
    mQ = to_cute(q)
    mK = to_cute(k)
    mV = to_cute(v_kernel)
    mC = to_cute(c_tile)
    mLSE = to_cute(lse_tensor)

    print(label, flush=True)
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE)
    compiled(mQ, mK, mV, mC, stream, mLSE)
    torch.cuda.synchronize()

    out = c_tile[:, :, 0].float()
    cos = torch.nn.functional.cosine_similarity(
        out.flatten().unsqueeze(0), ref_unnorm.flatten().unsqueeze(0)
    ).item()
    print(f' cos_unnorm={cos:.6f} {"PASS" if cos >= 0.99 else "FAIL"}')
    return cos


def test():
    """Compare the s_k=128 baseline against the s_k=256 (O rescale) run.

    Both runs share the same Q. K and V for the second run are drawn after
    the first run, so the random sequence matches the order q, k, v, k2, v2.
    """
    hd = 64
    m = 128
    torch.manual_seed(42)
    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')

    # Test 1: s_k=128 baseline (no rescale) — should be PASS
    s_k = 128  # 1 KV tile, no rescale needed
    k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda')
    _run_first_pv_tile(q, k, v, s_k, 'Test 1: s_k=128 baseline (no rescale)')

    # Test 2: s_k=256 with rescale — this is the failing test
    s_k2 = 256
    k2 = torch.randn(s_k2, hd, 1, dtype=torch.bfloat16, device='cuda')
    v2 = torch.randn(s_k2, hd, dtype=torch.bfloat16, device='cuda')
    _run_first_pv_tile(q, k2, v2, s_k2, 'Test 2: s_k=256 with O rescale')
# Run the baseline-vs-rescale comparison when this file is executed as a script.
if __name__ == "__main__":
    test()