D1.3: Add SMEM-P coordinate diagnostic test

This commit is contained in:
2026-05-23 23:23:05 +00:00
parent 27123f82ba
commit 68df389c93
2 changed files with 71 additions and 3 deletions

View File

@@ -353,10 +353,8 @@ class FmhaKernel:
else:
# SMEM-P: write P to sP using coordinate-indexed store.
# tTMEM_LOADcS contains (m, k) coordinates from identity tensor.
# Shape: ((32,1),4,1,1) — indexed with 4 indices.
# Each element is an (m, k) coordinate pair.
# Extract m with .load()[0] and k with .load()[1],
# or use indexing tTMEM_LOADcS[...].value[0/1].
# rP_bf16 has the same shape/layout as tTMEM_LOADcS (BF16 view of FP32 registers).
for j0 in range(32):
for j1 in range(4):
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]

View File

@@ -0,0 +1,70 @@
"""
D1.3 SMEM-P coordinate mapping diagnostic.
Verifies that tTMEM_LOADcS coordinates and rP_bf16 values
correctly map to sP indices for the SMEM-P path.
Runs a minimal kernel that writes P to sP and reads it back.
"""
import torch, math
import cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16
from cutlass.utils import LayoutEnum
import cutlass.torch as ct
import cuda.bindings.driver as cuda
# Test: write known P values to sP using coordinate indexing, then read back
# via the PV MMA's A-operand fragment and verify.
def test_smem_p_coords():
print("=== SMEM-P Coordinate Diagnostic ===\n")
hd = 256; m = 128; s_k = 128; pv_n_tile = 256
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(s_k, pv_n_tile, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
# Simple reference: just compute Q@K^T softmax @ V
qf = q[:, :, 0].float()
kf = k[:, :, 0].float()
vf = v.float()
scale = 1.0 / math.sqrt(hd)
attn = qf @ kf.T * scale
ref = torch.softmax(attn, dim=-1) @ vf # (128, 256)
# Run the kernel with use_smem_p=True
from dsv4.kernels.attention.fmha import FmhaKernel
kern = FmhaKernel(head_dim=hd, s_k=s_k)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
v_tile = v.unsqueeze(-1) # (128, 256, 1)
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
print('Compiling...', flush=True)
compiled = cute.compile(kern, mQ, mK, mV, mC, stream)
compiled(mQ, mK, mV, mC, stream)
torch.cuda.synchronize()
out = c[:, :, 0].float()
cos = torch.nn.functional.cosine_similarity(
out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
).item()
max_abs = (out - ref).abs().max().item()
print(f'hd=256 SMEM-P: cos {cos:.6f} max_abs {max_abs:.4f}')
print(f' out[0,:4]={out[0,:4].tolist()}')
print(f' ref[0,:4]={ref[0,:4].tolist()}')
# Also check: are the output values in a reasonable range?
print(f' out range: [{out.min().item():.4f}, {out.max().item():.4f}]')
print(f' ref range: [{ref.min().item():.4f}, {ref.max().item():.4f}]')
print(f' out has NaN: {torch.isnan(out).any().item()}')
print(f' out has inf: {torch.isinf(out).any().item()}')
if __name__ == '__main__':
test_smem_p_coords()