"""D1.5 Debug: Test s_k=256 in-kernel rescale with diagnostics.

Minimal test to isolate the TMEM round-trip vs barrier issue.
"""
import math

import torch
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda

from dsv4.kernels.attention.fmha import FmhaKernel


def reference_attention(q, k, v, scale, normalize=False):
    """Compute a float32 attention reference for 2-D ``q``, ``k`` and ``v``.

    Args:
        q: Query matrix; multiplied against ``k.T``.
        k: Key matrix.
        v: Value matrix.
        scale: Factor applied to the raw ``q @ k.T`` scores.
        normalize: When False (default, the original behaviour) return the
            un-normalized numerator ``exp(scores - rowmax) @ v``.  When True,
            additionally divide each row by ``sum(exp(scores - rowmax))`` so
            the result is the usual softmax-weighted output.

    Returns:
        A float32 tensor with one row per query row and one column per
        column of ``v``.
    """
    scores = (q.float() @ k.float().T) * scale
    row_max = scores.max(dim=-1, keepdim=True)[0]
    weights = torch.exp(scores - row_max)
    out = weights @ v.float()
    if normalize:
        out = out / weights.sum(dim=-1, keepdim=True)
    return out


def _as_cute(tensor):
    """Wrap a torch tensor as a CuTe tensor with a dynamic layout."""
    return ct.from_dlpack(tensor).mark_layout_dynamic(
        leading_dim=ct.get_leading_dim(tensor)
    )


def _v_tile(v, pv_n_tile):
    """Return the first ``pv_n_tile`` columns of ``v`` with a trailing unit dim."""
    return v[:, 0:pv_n_tile].contiguous().unsqueeze(-1)


def _flat_cosine(a, b):
    """Cosine similarity between two tensors, each flattened to one vector."""
    return torch.nn.functional.cosine_similarity(
        a.flatten().unsqueeze(0), b.flatten().unsqueeze(0)
    ).item()


def _launch(kernel, m_q, k, v_tile, m, stream):
    """Allocate outputs, compile ``kernel`` for these operands and launch it.

    The launch is asynchronous on ``stream``; the caller must synchronize
    before reading the returned buffers.

    Returns:
        ``(out, lse, row_sums)`` torch tensors that the kernel writes into.
    """
    # The output is as wide as the V tile handed to the kernel, so every
    # launch in this script uses the same, consistent output width.
    out = torch.zeros(m, v_tile.shape[1], 1, dtype=torch.bfloat16, device='cuda')
    lse = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    row_sums = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')

    m_k = _as_cute(k)
    m_v = _as_cute(v_tile)
    m_c = _as_cute(out)
    m_lse = _as_cute(lse)
    m_rs = _as_cute(row_sums)

    compiled = cute.compile(kernel, m_q, m_k, m_v, m_c, stream, m_lse, row_sums=m_rs)
    compiled(m_q, m_k, m_v, m_c, stream, m_lse, row_sums=m_rs)
    return out, lse, row_sums


def test():
    """Run the three diagnostic comparisons and print their results.

    1. s_k=128 baseline kernel vs. the un-normalized reference.
    2. s_k=256 kernel (in-kernel rescale) vs. the un-normalized reference,
       plus LSE / row-sum ranges.
    3. Two s_k=128 launches over the halves of the s_k=256 K/V, merged in
       Python, vs. the normalized reference.

    Requires a CUDA device and the CuTe DSL / ``FmhaKernel`` packages.
    """
    hd = 64
    m = 128
    scale = 1.0 / math.sqrt(hd)
    torch.manual_seed(42)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # Random inputs are drawn in the original order (q, k1, v1, k2, v2) so the
    # seeded values are unchanged.
    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    m_q = _as_cute(q)
    q2d = q[:, :, 0]

    # Test 1: s_k=128 baseline
    k1 = torch.randn(128, hd, 1, dtype=torch.bfloat16, device='cuda')
    v1 = torch.randn(128, hd, dtype=torch.bfloat16, device='cuda')
    kernel1 = FmhaKernel(head_dim=hd, s_k=128, use_smem_p=False, normalize=False)
    pv_n_tile = kernel1.pv_n_tile
    v1_tile = _v_tile(v1, pv_n_tile)
    c1, _, _ = _launch(kernel1, m_q, k1, v1_tile, m, stream)
    torch.cuda.synchronize()
    # The kernel only sees the first pv_n_tile columns of V, so the reference
    # must be restricted to the same columns.
    ref1 = reference_attention(q2d, k1[:, :, 0], v1[:, :pv_n_tile], scale)
    cos1 = _flat_cosine(c1[:, :, 0].float(), ref1)
    print(f's_k=128 baseline: cos={cos1:.6f} {"PASS" if cos1 >= 0.999 else "FAIL"}', flush=True)

    # Test 2: s_k=256 with in-kernel rescale
    k2 = torch.randn(256, hd, 1, dtype=torch.bfloat16, device='cuda')
    v2 = torch.randn(256, hd, dtype=torch.bfloat16, device='cuda')
    kernel2 = FmhaKernel(head_dim=hd, s_k=256, use_smem_p=False, normalize=False)
    v2_tile = _v_tile(v2, pv_n_tile)
    c2, lse2, rs2 = _launch(kernel2, m_q, k2, v2_tile, m, stream)
    torch.cuda.synchronize()
    ref2 = reference_attention(q2d, k2[:, :, 0], v2[:, :pv_n_tile], scale)
    out2 = c2[:, :, 0].float()
    cos2 = _flat_cosine(out2, ref2)
    # Per-element stats
    diff2 = (out2 - ref2).abs()
    max_rel_err = (diff2 / ref2.abs().clamp(min=1e-6)).max().item()
    print(f's_k=256 in-kernel rescale: cos={cos2:.6f} max_rel_err={max_rel_err:.4f} {"PASS" if cos2 >= 0.999 else "FAIL"}', flush=True)
    # Print LSE and row_sums for debugging
    print(f' LSE range: [{lse2[:, 0, 0].min().item():.4f}, {lse2[:, 0, 0].max().item():.4f}]', flush=True)
    print(f' row_sums range: [{rs2[:, 0, 0].min().item():.4f}, {rs2[:, 0, 0].max().item():.4f}]', flush=True)

    # Test 3: Python KV merge for comparison.  Each half of K/V goes through
    # the s_k=128 kernel; both launches are queued before one synchronize.
    v2_lo = _v_tile(v2[:128], pv_n_tile)
    v2_hi = _v_tile(v2[128:], pv_n_tile)
    c_s0, lse_s0, rs_s0 = _launch(kernel1, m_q, k2[:128], v2_lo, m, stream)
    c_s1, lse_s1, rs_s1 = _launch(kernel1, m_q, k2[128:], v2_hi, m, stream)
    torch.cuda.synchronize()

    o0 = c_s0[:, :, 0].float()
    o1 = c_s1[:, :, 0].float()
    r0 = rs_s0[:, 0, 0].float()
    r1 = rs_s1[:, 0, 0].float()
    l0 = lse_s0[:, 0, 0].float()
    l1 = lse_s1[:, 0, 0].float()

    # NOTE(review): assumes row_sums is the per-row softmax denominator of the
    # chunk and lse is a natural-log log-sum-exp -- confirm against FmhaKernel.
    o0_norm = o0 / r0.unsqueeze(1).clamp(min=1e-30)
    o1_norm = o1 / r1.unsqueeze(1).clamp(min=1e-30)
    # Subtract the larger LSE before exponentiating: the ratio w / (w0 + w1)
    # is unchanged, but exp() can no longer overflow for large LSE values.
    lse_max = torch.maximum(l0, l1)
    w0 = torch.exp(l0 - lse_max).unsqueeze(1)
    w1 = torch.exp(l1 - lse_max).unsqueeze(1)
    oracle = (w0 * o0_norm + w1 * o1_norm) / (w0 + w1)

    # The merged oracle is divided by the row sums, so it has to be compared
    # with the normalized reference; against the un-normalized ref2 the
    # per-row scale mismatch would lower the flattened cosine on its own.
    ref2_norm = reference_attention(
        q2d, k2[:, :, 0], v2[:, :pv_n_tile], scale, normalize=True
    )
    cos_oracle = _flat_cosine(oracle, ref2_norm)
    print(f'Python KV merge: cos={cos_oracle:.6f}', flush=True)


if __name__ == '__main__':
    test()