D1.3: Add SMEM-P vs TMEM-P comparison test

This commit is contained in:
2026-05-24 00:10:18 +00:00
parent d56e5601bb
commit 0fc6530f3f

View File

@@ -0,0 +1,88 @@
"""
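D1.3 SMEM-P: Focused test of the coordinate-indexed SMEM-P write.
Runs the FMHA kernel at hd=64 with TMEM-P (baseline) and with SMEM-P, and
compares each run's cosine similarity against a float32 PyTorch reference.
Also runs SMEM-P at hd=128 and hd=256.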
"""
import torch, math
import cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16
from cutlass.utils import LayoutEnum
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def test_hd(hd, use_smem_p, s_k=128, seed=0):
    """Run ``FmhaKernel`` once and score it against a float32 PyTorch reference.

    Builds random bfloat16 Q/K/V tensors on the CUDA device, computes
    ``softmax(Q @ K^T / sqrt(hd)) @ V`` in float32 as the reference, then
    compiles and launches ``FmhaKernel`` on the same tensors and prints
    accuracy statistics (cosine similarity, max absolute error, NaN/Inf
    flags) plus the first four values of row 0 of both results.

    Inputs are drawn from a dedicated generator seeded with ``seed`` rather
    than from the global RNG, so two calls with the same ``hd``, ``s_k`` and
    ``seed`` use identical Q/K/V whatever ``use_smem_p`` is. That makes the
    printed values and the returned similarity comparable between the
    TMEM-P and SMEM-P runs, and leaves the global RNG state untouched.

    Args:
        hd: Head dimension; the column count of Q and K. The V/output width
            is ``min(hd, 256)``.
        use_smem_p: Forwarded to ``FmhaKernel(use_smem_p=...)``; also picks
            the "SMEM-P" / "TMEM-P" label used in the printed output.
        s_k: Number of rows in K and V; forwarded to ``FmhaKernel(s_k=...)``.
        seed: Seed for the input generator.

    Returns:
        Cosine similarity, as a Python float, between the flattened kernel
        output and the flattened reference.
    """
    s_q = 128  # number of query rows; fixed for this test
    pv_n = min(hd, 256)  # V / output width, capped at 256

    # Local seeded generator: same inputs for every mode, reproducible runs.
    gen = torch.Generator(device='cuda')
    gen.manual_seed(seed)

    def rand_bf16(*shape):
        return torch.randn(
            *shape, dtype=torch.bfloat16, device='cuda', generator=gen
        )

    # Q, K and the output buffer carry a trailing size-1 dimension; V gets
    # one added below before it is handed to the kernel.
    q = rand_bf16(s_q, hd, 1)
    k = rand_bf16(s_k, hd, 1)
    v = rand_bf16(s_k, pv_n)
    c = torch.zeros(s_q, pv_n, 1, dtype=torch.bfloat16, device='cuda')

    # Float32 reference: scaled dot-product attention over the same inputs.
    qf = q[:, :, 0].float()
    kf = k[:, :, 0].float()
    vf = v.float()
    scale = 1.0 / math.sqrt(hd)
    attn = qf @ kf.T * scale
    ref = torch.softmax(attn, dim=-1) @ vf

    # NOTE(review): normalize=True presumably makes the kernel apply the
    # softmax normalisation itself, matching `ref` -- confirm in FmhaKernel.
    kern = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=use_smem_p, normalize=True)
    # Launch on the stream PyTorch is currently using.
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    v_tile = v.unsqueeze(-1)

    def wrap(t):
        # Hand the torch tensor to the kernel via DLPack, with a dynamic
        # layout keyed on the tensor's leading dimension.
        return ct.from_dlpack(t).mark_layout_dynamic(
            leading_dim=ct.get_leading_dim(t)
        )

    mQ = wrap(q)
    mK = wrap(k)
    mV = wrap(v_tile)
    mC = wrap(c)

    mode = "SMEM-P" if use_smem_p else "TMEM-P"
    print(f'Compiling hd={hd} {mode}...', flush=True)
    compiled = cute.compile(kern, mQ, mK, mV, mC, stream)
    compiled(mQ, mK, mV, mC, stream)
    # Wait for the launch to finish before reading the output buffer.
    torch.cuda.synchronize()

    out = c[:, :, 0].float()
    cos = torch.nn.functional.cosine_similarity(
        out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
    ).item()
    max_abs = (out - ref).abs().max().item()
    has_nan = torch.isnan(out).any().item()
    has_inf = torch.isinf(out).any().item()
    print(f' hd={hd} {mode}: cos={cos:.6f} max_abs={max_abs:.6f} NaN={has_nan} Inf={has_inf}')
    # Print first few values of kernel output and reference for comparison
    print(f' out[0,:4]={out[0,:4].tolist()}')
    print(f' ref[0,:4]={ref[0,:4].tolist()}')
    print()
    return cos
if __name__ == '__main__':
    # Script driver: run the TMEM-P baseline at hd=64, then SMEM-P at each
    # head dimension, and print a summary of the cosine similarities.
    print("=== SMEM-P vs TMEM-P Comparison ===\n")

    # Baseline first (TMEM-P at hd=64, the proven path), then SMEM-P at
    # hd=64, 128 and 256 in that order.
    baseline_cos = test_hd(64, use_smem_p=False)
    smem_cos = {hd: test_hd(hd, use_smem_p=True) for hd in (64, 128, 256)}

    print("=== Summary ===")
    print(f"hd=64 TMEM-P: cos={baseline_cos:.6f} (baseline)")
    print(f"hd=64 SMEM-P: cos={smem_cos[64]:.6f} (should match TMEM-P)")
    print(f"hd=128 SMEM-P: cos={smem_cos[128]:.6f}")
    print(f"hd=256 SMEM-P: cos={smem_cos[256]:.6f}")

    # NOTE(review): this compares each run's cosine similarity against its
    # own float32 reference; it does not compare the two kernel outputs
    # element-wise.
    delta = abs(baseline_cos - smem_cos[64])
    if delta < 0.01:
        print("\n✅ SMEM-P matches TMEM-P at hd=64 — coordinate mapping is correct!")
    else:
        print(f"\n⚠️ SMEM-P differs from TMEM-P by {delta:.6f} — coordinate mapping has issues")