From 68df389c93c954e191d499b32f3ae96c28fa3530 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 23:23:05 +0000 Subject: [PATCH] D1.3: Add SMEM-P coordinate diagnostic test --- dsv4/kernels/attention/fmha.py | 4 +- tests/unit/test_d1_3_smem_diag.py | 70 +++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 3 deletions(-) create mode 100644 tests/unit/test_d1_3_smem_diag.py diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 6b2e517b..446a9340 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -353,10 +353,8 @@ class FmhaKernel: else: # SMEM-P: write P to sP using coordinate-indexed store. # tTMEM_LOADcS contains (m, k) coordinates from identity tensor. - # Shape: ((32,1),4,1,1) — indexed with 4 indices. # Each element is an (m, k) coordinate pair. - # Extract m with .load()[0] and k with .load()[1], - # or use indexing tTMEM_LOADcS[...].value[0/1]. + # rP_bf16 has the same shape/layout as tTMEM_LOADcS (BF16 view of FP32 registers). for j0 in range(32): for j1 in range(4): coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] diff --git a/tests/unit/test_d1_3_smem_diag.py b/tests/unit/test_d1_3_smem_diag.py new file mode 100644 index 00000000..fc6134fd --- /dev/null +++ b/tests/unit/test_d1_3_smem_diag.py @@ -0,0 +1,70 @@ +""" +D1.3 SMEM-P coordinate mapping diagnostic. +Verifies that tTMEM_LOADcS coordinates and rP_bf16 values +correctly map to sP indices for the SMEM-P path. +Runs a minimal kernel that writes P to sP and reads it back. +""" +import torch, math +import cutlass, cutlass.cute as cute, cutlass.utils as utils +from cutlass.cute.nvgpu import tcgen05 +from cutlass import Float32, BFloat16 +from cutlass.utils import LayoutEnum +import cutlass.torch as ct +import cuda.bindings.driver as cuda + +# Test: write known P values to sP using coordinate indexing, then read back +# via the PV MMA's A-operand fragment and verify. 
+
+
+def test_smem_p_coords():
+    """Run FmhaKernel at hd=256 and compare against a PyTorch FP32 reference.
+
+    Prints diagnostics, then asserts finiteness and closeness to the reference.
+    """
+    print("=== SMEM-P Coordinate Diagnostic ===\n")
+    hd, m, s_k, pv_n_tile = 256, 128, 128, 256
+    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
+    k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
+    v = torch.randn(s_k, pv_n_tile, dtype=torch.bfloat16, device='cuda')
+    c = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
+
+    # FP32 reference: softmax(Q @ K^T / sqrt(hd)) @ V, shape (m, pv_n_tile).
+    qf, kf, vf = q[:, :, 0].float(), k[:, :, 0].float(), v.float()
+    attn = qf @ kf.T * (1.0 / math.sqrt(hd))
+    ref = torch.softmax(attn, dim=-1) @ vf
+
+    # NOTE(review): no use_smem_p flag is passed; confirm hd=256 takes the SMEM-P path.
+    from dsv4.kernels.attention.fmha import FmhaKernel
+    kern = FmhaKernel(head_dim=hd, s_k=s_k)
+
+    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+    v_tile = v.unsqueeze(-1)  # (s_k, pv_n_tile, 1)
+    mQ, mK, mV, mC = (
+        ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))
+        for t in (q, k, v_tile, c)
+    )
+
+    print('Compiling...', flush=True)
+    compiled = cute.compile(kern, mQ, mK, mV, mC, stream)
+    compiled(mQ, mK, mV, mC, stream)
+    torch.cuda.synchronize()
+
+    out = c[:, :, 0].float()
+    cos = torch.nn.functional.cosine_similarity(
+        out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
+    ).item()
+    max_abs = (out - ref).abs().max().item()
+    print(f'hd=256 SMEM-P: cos {cos:.6f} max_abs {max_abs:.4f}')
+    print(f' out[0,:4]={out[0,:4].tolist()}')
+    print(f' ref[0,:4]={ref[0,:4].tolist()}')
+    print(f' out range: [{out.min().item():.4f}, {out.max().item():.4f}]')
+    print(f' ref range: [{ref.min().item():.4f}, {ref.max().item():.4f}]')
+    print(f' out has NaN: {torch.isnan(out).any().item()}')
+    print(f' out has inf: {torch.isinf(out).any().item()}')
+    # Loose bound for BF16 output vs FP32 reference (TODO confirm); fail instead of only printing.
+    assert torch.isfinite(out).all(), "kernel output contains NaN/inf"
+    assert cos > 0.99, f"cosine similarity too low: {cos:.6f}"
+
+
+if __name__ == '__main__':
+    test_smem_p_coords()