"""D1: Debug O rescale at s_k=256 with diagnostic prints.

Runs ``FmhaKernel`` (``normalize=False``) on a 128-row Q block against a
multi-tile K/V sequence, one PV column tile of V/O at a time. It stitches the
per-tile outputs together and prints how they compare with an FP32 PyTorch
reference of the unnormalized attention output
``exp(Q K^T * scale - rowmax) @ V``.
"""
import math

import torch
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda

from dsv4.kernels.attention.fmha import FmhaKernel


def _as_cute(t):
    """Wrap torch tensor *t* (via DLPack) as a CuTe tensor with a dynamic layout."""
    return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))


def test_multi_kv_debug(hd=64, s_k=256):
    """Run the FMHA kernel once per PV column tile and print diagnostics.

    Args:
        hd: Head dimension of Q, K and V.
        s_k: Key/value sequence length. The per-K-tile attention-range printout
            only covers the first ``s_k // 128`` full 128-row K tiles.

    Returns:
        Cosine similarity (float) between the flattened kernel output and the
        flattened FP32 unnormalized reference ``exp(S - rowmax(S)) @ V``.
    """
    m = 128  # number of Q rows in this debug run
    kv_tile = 128  # K rows per tile in the per-tile diagnostic printout
    n_kv_tiles = s_k // kv_tile

    # Allocation order matters: it fixes which random values each tensor gets.
    torch.manual_seed(42)
    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda')
    c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')

    # FP32 reference, computed from the same bf16 inputs.
    qf = q[:, :, 0].float()
    kf = k[:, :, 0].float()
    scale = 1.0 / math.sqrt(hd)
    logits = qf @ kf.T * scale  # (m, s_k)
    attn_max = logits.amax(dim=-1, keepdim=True)
    attn_exp = torch.exp(logits - attn_max)
    attn_sum = attn_exp.sum(dim=-1, keepdim=True)
    # Unnormalized output: this script assumes normalize=False makes the kernel
    # emit O before division by the row sum (this is what it compares against).
    ref_unnorm = attn_exp @ v.float()

    # Per-K-tile logit ranges, to spot tiles whose max differs from tile 0
    # (those are the tiles where the running-max O rescale actually kicks in).
    for kt in range(n_kv_tiles):
        k_start = kt * kv_tile
        k_end = k_start + kv_tile
        attn_t = logits[:, k_start:k_end]
        print(f"  kt={kt}: K[{k_start}:{k_end}] attn range "
              f"[{attn_t.min().item():.4f}, {attn_t.max().item():.4f}]")

    lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    kernel = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=False, normalize=False)
    pv_n_tile = kernel.pv_n_tile
    n_pv_tiles = kernel.n_pv_tiles
    print(f"  n_kv_tiles={kernel.n_kv_tiles}, pv_n_tile={pv_n_tile}, n_pv_tiles={n_pv_tiles}")
    if n_pv_tiles * pv_n_tile != hd:
        # Stitching below assumes the PV tiles exactly cover the head dimension.
        print(f"  WARNING: n_pv_tiles*pv_n_tile={n_pv_tiles * pv_n_tile} != hd={hd}; "
              f"stitched output will not line up with the reference")
    # NOTE: tmem_o0_offset is only set in the kernel's _setup (not __init__),
    # so it is not available to print here.

    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    compiled = None
    lse_val = None
    for nt in range(n_pv_tiles):
        v_start = nt * pv_n_tile
        v_end = v_start + pv_n_tile
        # Keep the torch tensors bound to locals so they stay alive until the
        # kernel has finished (synchronize below).
        v_kernel = v[:, v_start:v_end].contiguous().unsqueeze(-1)
        c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
        lse_tensor.zero_()
        args = (
            _as_cute(q),
            _as_cute(k),
            _as_cute(v_kernel),
            _as_cute(c_tile),
            stream,
            _as_cute(lse_tensor),
        )
        if compiled is None:
            # Compile once with the first PV tile's arguments; reuse afterwards.
            print(' Compiling...', flush=True)
            compiled = cute.compile(kernel, *args)
        compiled(*args)
        torch.cuda.synchronize()
        c[:, v_start:v_end, :] = c_tile
        if nt == 0:
            lse_val = lse_tensor[0, 0, 0].item()

    out_unnorm = c[:, :, 0].float()

    # Per-row similarity for the first few rows (vectorised over rows).
    n_rows = min(8, m)
    row_cos = torch.nn.functional.cosine_similarity(
        out_unnorm[:n_rows], ref_unnorm[:n_rows], dim=-1
    ).tolist()
    cos_unnorm = torch.nn.functional.cosine_similarity(
        out_unnorm.flatten().unsqueeze(0), ref_unnorm.flatten().unsqueeze(0)
    ).item()

    print(f"  cos_unnorm={cos_unnorm:.6f}")
    print(f"  row 0 cos={row_cos[0]:.6f}  row 1 cos={row_cos[1]:.6f}")
    print(f"  out[0,:8]={out_unnorm[0,:8].tolist()}")
    print(f"  ref[0,:8]={ref_unnorm[0,:8].tolist()}")
    # Reference LSE uses the natural log; NOTE(review): if the kernel stores a
    # log2-based LSE, these two values will differ by a factor of ln(2).
    print(f"  lse_val={lse_val}, ref_lse={(torch.log(attn_sum[0,0]) + attn_max[0,0]).item()}")
    return cos_unnorm


def test():
    """Entry point: run the multi-KV debug case at hd=64, s_k=256."""
    print("=== D1: Multi-KV Debug ===\n")
    test_multi_kv_debug(64, 256)


if __name__ == '__main__':
    test()