117 lines
6.6 KiB
Python
117 lines
6.6 KiB
Python
"""
|
|
D1.5 Debug: Test s_k=256 in-kernel rescale with diagnostics.
|
|
Minimal test to isolate the TMEM round-trip vs barrier issue.
|
|
"""
|
|
import torch, math
|
|
import cutlass.cute as cute
|
|
import cutlass.torch as ct
|
|
import cuda.bindings.driver as cuda
|
|
from dsv4.kernels.attention.fmha import FmhaKernel
|
|
|
|
|
|
def reference_attention(q, k, v, scale):
    """Un-normalized attention reference, computed in float32.

    Evaluates ``exp(S - rowmax(S)) @ v`` where ``S = (q @ k.T) * scale``.
    The exponentials are NOT divided by their row sums, so the result is the
    softmax numerator applied to ``v`` rather than a full softmax-weighted
    average.

    Args:
        q: 2-D query tensor; cast to float32 before use.
        k: 2-D key tensor; cast to float32 and transposed.
        v: value tensor; cast to float32.
        scale: multiplier applied to the raw ``q @ k.T`` scores.

    Returns:
        A float32 tensor holding ``exp(S - rowmax(S)) @ v``.
    """
    scores = (q.float() @ k.float().T) * scale
    # Shift every row by its own maximum before exponentiating.
    row_max = scores.max(dim=-1, keepdim=True).values
    weights = (scores - row_max).exp()
    return weights @ v.float()
|
|
|
|
|
|
def _to_cute(t):
    """Wrap torch tensor ``t`` for a kernel call.

    Converts through ``ct.from_dlpack`` and marks the layout dynamic along
    the leading dimension reported by ``ct.get_leading_dim``.
    """
    return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))


def _launch_fmha(kernel, m_q, k, v_tile, out, stream):
    """Compile ``kernel`` for the given operands and launch it once.

    Args:
        kernel: the ``FmhaKernel`` instance to compile.
        m_q: the query operand, already wrapped with ``_to_cute``.
        k: torch key tensor.
        v_tile: torch value tensor handed to the kernel.
        out: torch output buffer handed to the kernel.
        stream: stream object passed to both ``cute.compile`` and the launch.

    Returns:
        ``(lse, row_sums)``: zero-initialized float32 ``(rows, 1, 1)`` CUDA
        buffers (``rows = out.shape[0]``) that were passed to the kernel.

    This helper does not synchronize; callers must call
    ``torch.cuda.synchronize()`` before reading ``out`` or the returned
    buffers.
    """
    rows = out.shape[0]
    lse = torch.zeros(rows, 1, 1, dtype=torch.float32, device='cuda')
    row_sums = torch.zeros(rows, 1, 1, dtype=torch.float32, device='cuda')
    m_k = _to_cute(k)
    m_v = _to_cute(v_tile)
    m_c = _to_cute(out)
    m_lse = _to_cute(lse)
    m_rs = _to_cute(row_sums)
    compiled = cute.compile(kernel, m_q, m_k, m_v, m_c, stream, m_lse, row_sums=m_rs)
    compiled(m_q, m_k, m_v, m_c, stream, m_lse, row_sums=m_rs)
    return lse, row_sums


def _flat_cosine(a, b):
    """Cosine similarity between ``a`` and ``b`` after flattening both to 1-D."""
    return torch.nn.functional.cosine_similarity(
        a.flatten().unsqueeze(0), b.flatten().unsqueeze(0)
    ).item()


def test():
    """Run three FMHA diagnostics and print one summary line per experiment.

    1. ``s_k=128`` baseline kernel vs. ``reference_attention``.
    2. ``s_k=256`` kernel (in-kernel rescale) vs. ``reference_attention``,
       plus the max relative error and the ranges of the LSE / row-sum
       buffers.
    3. Two ``s_k=128`` launches over the two halves of the ``s_k=256`` K/V,
       merged in Python and compared against the same reference.

    Nothing is asserted: PASS/FAIL (threshold ``cos >= 0.999``) is only
    printed. Requires a CUDA device.
    """
    hd = 64
    m = 128
    scale = 1.0 / math.sqrt(hd)
    torch.manual_seed(42)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    q2d = q[:, :, 0]
    m_q = _to_cute(q)

    # Test 1: s_k=128 baseline
    k1 = torch.randn(128, hd, 1, dtype=torch.bfloat16, device='cuda')
    v1 = torch.randn(128, hd, dtype=torch.bfloat16, device='cuda')
    c1 = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    kernel1 = FmhaKernel(head_dim=hd, s_k=128, use_smem_p=False, normalize=False)
    pv_n_tile = kernel1.pv_n_tile
    v1t = v1[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
    # Only these V columns are handed to the kernel, so only they can be
    # compared against the reference. When pv_n_tile == hd this is every
    # column and the comparison is unchanged.
    v_cols = v1t.shape[1]
    _launch_fmha(kernel1, m_q, k1, v1t, c1, stream)
    torch.cuda.synchronize()
    ref1 = reference_attention(q2d, k1[:, :, 0], v1[:, :v_cols], scale)
    cos1 = _flat_cosine(c1[:, :v_cols, 0].float(), ref1)
    print(f's_k=128 baseline: cos={cos1:.6f} {"PASS" if cos1 >= 0.999 else "FAIL"}', flush=True)

    # Test 2: s_k=256 with in-kernel rescale
    k2 = torch.randn(256, hd, 1, dtype=torch.bfloat16, device='cuda')
    v2 = torch.randn(256, hd, dtype=torch.bfloat16, device='cuda')
    c2 = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
    kernel2 = FmhaKernel(head_dim=hd, s_k=256, use_smem_p=False, normalize=False)
    v2t = v2[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
    lse2, rs2 = _launch_fmha(kernel2, m_q, k2, v2t, c2, stream)
    torch.cuda.synchronize()

    ref2 = reference_attention(q2d, k2[:, :, 0], v2[:, :v_cols], scale)
    out2 = c2[:, :v_cols, 0].float()
    cos2 = _flat_cosine(out2, ref2)

    # Per-element stats
    diff2 = (out2 - ref2).abs()
    max_rel_err = (diff2 / ref2.abs().clamp(min=1e-6)).max().item()
    print(f's_k=256 in-kernel rescale: cos={cos2:.6f} max_rel_err={max_rel_err:.4f} {"PASS" if cos2 >= 0.999 else "FAIL"}', flush=True)

    # Print LSE and row_sums for debugging
    print(f' LSE range: [{lse2[:, 0, 0].min().item():.4f}, {lse2[:, 0, 0].max().item():.4f}]', flush=True)
    print(f' row_sums range: [{rs2[:, 0, 0].min().item():.4f}, {rs2[:, 0, 0].max().item():.4f}]', flush=True)

    # Test 3: Python KV merge for comparison
    halves = []
    for rows in (slice(0, 128), slice(128, None)):
        c_half = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
        v_half = v2[rows, 0:pv_n_tile].contiguous().unsqueeze(-1)
        lse_half, rs_half = _launch_fmha(kernel1, m_q, k2[rows], v_half, c_half, stream)
        # v_half is stored too so the kernel input stays referenced until
        # the synchronize below.
        halves.append((c_half, lse_half, rs_half, v_half))
    torch.cuda.synchronize()

    (c_s0, lse_s0, rs_s0, _), (c_s1, lse_s1, rs_s1, _) = halves
    o0 = c_s0[:, :v_cols, 0].float()
    o1 = c_s1[:, :v_cols, 0].float()
    r0 = rs_s0[:, 0, 0].float()
    r1 = rs_s1[:, 0, 0].float()
    l0 = lse_s0[:, 0, 0].float()
    l1 = lse_s1[:, 0, 0].float()
    o0_norm = o0 / r0.unsqueeze(1).clamp(min=1e-30)
    o1_norm = o1 / r1.unsqueeze(1).clamp(min=1e-30)
    # Subtract the shared per-row max before exponentiating so the weights
    # cannot overflow float32; the common factor cancels in the ratio below.
    l_max = torch.maximum(l0, l1)
    w0 = torch.exp(l0 - l_max).unsqueeze(1)
    w1 = torch.exp(l1 - l_max).unsqueeze(1)
    # NOTE(review): oracle is divided by (w0 + w1) while reference_attention
    # returns un-normalized numerators, and a flattened cosine is not
    # invariant to per-row scaling - confirm this comparison is intended.
    oracle = (w0 * o0_norm + w1 * o1_norm) / (w0 + w1)
    cos_oracle = _flat_cosine(oracle, ref2)
    print(f'Python KV merge: cos={cos_oracle:.6f}', flush=True)
|
|
|
|
|
|
# Script entry point: run the diagnostics only when executed directly.
if __name__ == "__main__":
    test()
|