From c042fcf6c7fc58d5e9f7bc52d2a99b8039d71e9a Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 03:22:23 +0000 Subject: [PATCH] D1: Add diagnostic test (TMEM-P vs SMEM-P at various hd) --- tests/unit/test_d1_diag2.py | 146 ++++++++++++++++++++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 tests/unit/test_d1_diag2.py diff --git a/tests/unit/test_d1_diag2.py b/tests/unit/test_d1_diag2.py new file mode 100644 index 00000000..1c22a62b --- /dev/null +++ b/tests/unit/test_d1_diag2.py @@ -0,0 +1,146 @@ +""" +Quick D1 diagnostic: test TMEM-P path (use_smem_p=False) at various head dims. +The SMEM-P path (use_smem_p=True, hd>64) has coordinate mapping issues. +This test forces TMEM-P to verify the core pipeline works. +""" +import torch, math +import cutlass.cute as cute +import cutlass.torch as ct +import cuda.bindings.driver as cuda +from dsv4.kernels.attention.fmha import FmhaKernel + + +def test_tmem_p(hd, n_kv=128): + m = 128 + torch.manual_seed(42) + q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') + k = torch.randn(n_kv, hd, 1, dtype=torch.bfloat16, device='cuda') + v = torch.randn(n_kv, hd, dtype=torch.bfloat16, device='cuda') + c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') + + qf = q[:, :, 0].float() + kf = k[:, :, 0].float() + scale = 1.0 / math.sqrt(hd) + attn_max = (qf @ kf.T * scale).max(dim=-1, keepdim=True)[0] + attn_exp = torch.exp(qf @ kf.T * scale - attn_max) + attn_sum = attn_exp.sum(dim=-1, keepdim=True) + ref_norm = (attn_exp / attn_sum) @ v.float() + ref_unnorm = attn_exp @ v.float() + + lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') + # Force TMEM-P + kernel = FmhaKernel(head_dim=hd, s_k=n_kv, use_smem_p=False) + pv_n_tile = kernel.pv_n_tile + n_pv_tiles = kernel.n_pv_tiles + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + v_tile = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1) + c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile)) + mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) + mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor)) + + print(f'hd={hd} TMEM-P: Compiling...', flush=True) + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE) + + lse_val = None + for nt in range(n_pv_tiles): + vs, ve = nt * pv_n_tile, (nt + 1) * pv_n_tile + v_t = v[:, vs:ve].contiguous().unsqueeze(-1) + c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') + lse_tensor.zero_() + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mV = ct.from_dlpack(v_t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_t)) + mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) + mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor)) + compiled(mQ, mK, mV, mC, stream, mLSE) + torch.cuda.synchronize() + c[:, vs:ve, :] = c_tile + if nt == 0: + lse_val = lse_tensor[0, 0, 0].item() + + out = c[:, :, 0].float() + out_norm = out / attn_sum + cos_unnorm = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref_unnorm.flatten().unsqueeze(0)).item() + cos_norm = torch.nn.functional.cosine_similarity(out_norm.flatten().unsqueeze(0), ref_norm.flatten().unsqueeze(0)).item() + status = "PASS" if cos_unnorm >= 0.99 else "FAIL" + print(f'hd={hd} TMEM-P: cos_unnorm {cos_unnorm:.6f} cos_norm {cos_norm:.6f} lse {lse_val:.6f} {status}') + return cos_unnorm + + +def test_smem_p(hd, n_kv=128): + m = 128 + torch.manual_seed(42) + q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') + k = torch.randn(n_kv, hd, 1, dtype=torch.bfloat16, device='cuda') + v = torch.randn(n_kv, hd, dtype=torch.bfloat16, device='cuda') + c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') + + qf = q[:, :, 0].float() + kf = k[:, :, 0].float() + scale = 1.0 / math.sqrt(hd) + attn_max = (qf @ kf.T * scale).max(dim=-1, keepdim=True)[0] + attn_exp = torch.exp(qf @ kf.T * scale - attn_max) + attn_sum = attn_exp.sum(dim=-1, keepdim=True) + ref_unnorm = attn_exp @ v.float() + + lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') + kernel = FmhaKernel(head_dim=hd, s_k=n_kv, use_smem_p=True) + pv_n_tile = kernel.pv_n_tile + n_pv_tiles = kernel.n_pv_tiles + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + v_tile = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1) + c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile)) + mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) + mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor)) + + print(f'hd={hd} SMEM-P: Compiling...', flush=True) + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE) + + lse_val = None + for nt in range(n_pv_tiles): + vs, ve = nt * pv_n_tile, (nt + 1) * pv_n_tile + v_t = v[:, vs:ve].contiguous().unsqueeze(-1) + c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') + lse_tensor.zero_() + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mV = ct.from_dlpack(v_t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_t)) + mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) + mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor)) + compiled(mQ, mK, mV, mC, stream, mLSE) + torch.cuda.synchronize() + c[:, vs:ve, :] = c_tile + if nt == 0: + lse_val = lse_tensor[0, 0, 0].item() + + out = c[:, :, 0].float() + cos_unnorm = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref_unnorm.flatten().unsqueeze(0)).item() + status = "PASS" if cos_unnorm >= 0.99 else "FAIL" + print(f'hd={hd} SMEM-P: cos_unnorm {cos_unnorm:.6f} lse {lse_val:.6f} {status}') + if cos_unnorm < 0.97: + print(f' out[0,:4]={out[0,:4].tolist()}') + print(f' ref[0,:4]={ref_unnorm[0,:4].tolist()}') + return cos_unnorm + + +if __name__ == '__main__': + print("=== D1 Diagnostic ===\n") + + # TMEM-P path (proven at hd=64) + print("--- TMEM-P (force use_smem_p=False) ---") + test_tmem_p(64) + test_tmem_p(128) + test_tmem_p(256) + + # SMEM-P path (for hd>64) + print("\n--- SMEM-P (use_smem_p=True) ---") + test_smem_p(128) + test_smem_p(256)