D1.3: Add SMEM-P vs TMEM-P comparison test
This commit is contained in:
88
tests/unit/test_d1_3_smem_vs_tmem.py
Normal file
88
tests/unit/test_d1_3_smem_vs_tmem.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""
|
||||
D1.3 SMEM-P: Focused test of coordinate-indexed SMEM-P write.
|
||||
Tests at hd=64 with SMEM-P, compares to TMEM-P result.
|
||||
Also tests hd=128 and hd=256.
|
||||
"""
|
||||
import torch, math
|
||||
import cutlass, cutlass.cute as cute, cutlass.utils as utils
|
||||
from cutlass.cute.nvgpu import tcgen05
|
||||
from cutlass import Float32, BFloat16
|
||||
from cutlass.utils import LayoutEnum
|
||||
import cutlass.torch as ct
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
from dsv4.kernels.attention.fmha import FmhaKernel
|
||||
|
||||
|
||||
def test_hd(hd: int, use_smem_p: bool, s_k: int = 128) -> float:
    """Run ``FmhaKernel`` once on random inputs and score it against PyTorch.

    Random bfloat16 Q (128 x hd), K (s_k x hd) and V (s_k x pv_n) are created
    on the CUDA device, a float32 reference ``softmax(Q @ K^T / sqrt(hd)) @ V``
    is computed with PyTorch, and the kernel is compiled and launched on the
    current CUDA stream. Diagnostics (cosine similarity, max absolute error,
    NaN/Inf flags and the first four values of row 0) are printed.

    Args:
        hd: Head dimension, forwarded to the kernel as ``head_dim``.
        use_smem_p: Forwarded to the kernel as ``use_smem_p``; ``True`` is
            reported as "SMEM-P", ``False`` as "TMEM-P".
        s_k: Number of key/value rows, forwarded to the kernel as ``s_k``.

    Returns:
        Cosine similarity between the flattened kernel output and the
        flattened reference, as a Python float.
    """
    # Width of V and of the output tile is capped at 256.
    # NOTE(review): presumably a limit of the kernel's PV tile — confirm.
    pv_n = min(hd, 256)

    # Q, K and the output carry a trailing unit dimension; V is created 2-D
    # and gets its trailing dimension added below, just before conversion.
    q = torch.randn(128, hd, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(s_k, pv_n, dtype=torch.bfloat16, device='cuda')
    c = torch.zeros(128, pv_n, 1, dtype=torch.bfloat16, device='cuda')

    # Float32 reference: scaled dot-product attention with full softmax.
    qf = q[:, :, 0].float()
    kf = k[:, :, 0].float()
    vf = v.float()
    scale = 1.0 / math.sqrt(hd)
    attn = qf @ kf.T * scale
    ref = torch.softmax(attn, dim=-1) @ vf

    kern = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=use_smem_p, normalize=True)
    # Launch on the stream PyTorch is currently using for this device.
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # Give V the same trailing unit dimension as Q, K and the output.
    v_tile = v.unsqueeze(-1)
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))

    mode = "SMEM-P" if use_smem_p else "TMEM-P"
    print(f'Compiling hd={hd} {mode}...', flush=True)
    compiled = cute.compile(kern, mQ, mK, mV, mC, stream)
    compiled(mQ, mK, mV, mC, stream)
    # Wait for the launch to finish before reading the output tensor.
    torch.cuda.synchronize()

    out = c[:, :, 0].float()
    cos = torch.nn.functional.cosine_similarity(
        out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
    ).item()
    max_abs = (out - ref).abs().max().item()
    has_nan = torch.isnan(out).any().item()
    has_inf = torch.isinf(out).any().item()

    print(f' hd={hd} {mode}: cos={cos:.6f} max_abs={max_abs:.6f} NaN={has_nan} Inf={has_inf}')

    # Print first few values for comparison
    print(f' out[0,:4]={out[0,:4].tolist()}')
    print(f' ref[0,:4]={ref[0,:4].tolist()}')
    print()

    return cos


# This helper takes required arguments and is driven from the ``__main__``
# section. Because its name starts with ``test_`` pytest would otherwise
# collect it and fail trying to resolve ``hd`` / ``use_smem_p`` as fixtures.
test_hd.__test__ = False
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Script driver: run the TMEM-P baseline and three SMEM-P configurations,
    # then print a summary and a verdict for the hd=64 pair.
    print("=== SMEM-P vs TMEM-P Comparison ===\n")

    # Baseline first: TMEM-P at hd=64 (described as the proven path).
    cos_tmem = test_hd(64, use_smem_p=False)

    # SMEM-P at hd=64, 128 and 256, in that order. The hd=64 result is the
    # one compared against the baseline below.
    cos_smem, cos_128, cos_256 = [
        test_hd(head_dim, use_smem_p=True) for head_dim in (64, 128, 256)
    ]

    summary_lines = (
        "=== Summary ===",
        f"hd=64 TMEM-P: cos={cos_tmem:.6f} (baseline)",
        f"hd=64 SMEM-P: cos={cos_smem:.6f} (should match TMEM-P)",
        f"hd=128 SMEM-P: cos={cos_128:.6f}",
        f"hd=256 SMEM-P: cos={cos_256:.6f}",
    )
    for line in summary_lines:
        print(line)

    # The two hd=64 scores must agree to within 0.01. The comparison is kept
    # as "< 0.01" so that a NaN score lands in the warning branch.
    agrees = abs(cos_tmem - cos_smem) < 0.01
    verdict = (
        "\n✅ SMEM-P matches TMEM-P at hd=64 — coordinate mapping is correct!"
        if agrees
        else f"\n⚠️ SMEM-P differs from TMEM-P by {abs(cos_tmem - cos_smem):.6f} — coordinate mapping has issues"
    )
    print(verdict)
|
||||
Reference in New Issue
Block a user