- Both load and store atoms built from SAME tOtO_i (composition-tiled)
- Same Repetition(corr_tile_size=16) for both copies
- pv_done_bar synchronization between MMA and softmax warps
- acc_scale computed per kt iteration, used to rescale O in TMEM
- const_expr(n_kv_tiles > 1) guards for zero overhead at s_k=128
- New test: test_d15_in_kernel_rescale.py (s_k=128/256/384)
- Minimal roundtrip test: test_tmem_roundtrip_minimal.py
153 lines · 6.4 KiB · Python
"""
D1.5 Phase 4: Test in-kernel O rescale for multi-KV-tile FMHA.

Tests the CUTLASS correction_rescale pattern:
- Both load and store atoms built from the SAME tOtO_i (composition-tiled)
- Same Repetition(corr_tile_size=16) for both
- Rescale O in TMEM between PV iterations

Cases: s_k=128 (1 KV tile), s_k=256 (2 KV tiles), s_k=384 (3 KV tiles).

Compares against:
1. FP32 reference (ground truth)
2. Python KV merge (proven correct, cos 0.999998)
3. s_k=128 baseline (no rescale, regression check)
"""
|
|
import torch, math
|
|
import cutlass.cute as cute
|
|
import cutlass.torch as ct
|
|
import cuda.bindings.driver as cuda
|
|
from dsv4.kernels.attention.fmha import FmhaKernel
|
|
|
|
|
|
def reference_attention(q, k, v, scale):
    """FP32 reference: returns un-normalized O.

    Computes exp(S - rowmax(S)) @ V with S = (Q @ K^T) * scale, entirely in
    float32. This is softmax attention without the final division by the
    per-row sum of the exponentials.

    In this file it is called with 2-D q (m, d), k (n, d) and v (n, dv).
    """
    scores = (q.float() @ k.float().T) * scale
    # Shift each row by its maximum so exp() stays in range.
    shifted = scores - scores.amax(dim=-1, keepdim=True)
    return shifted.exp() @ v.float()
|
|
|
|
|
|
def _as_cute_tensor(t):
    """Wrap torch tensor `t` via DLPack and mark its layout dynamic along its leading dim."""
    return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))


def run_fmha(q, k, v, head_dim, s_k, pv_n_tile, use_smem_p, stream, lse_tensor, row_sums_tensor):
    """Run FMHA kernel and return output tensor.

    Builds ``FmhaKernel(head_dim, s_k, use_smem_p, normalize=False)``, compiles
    it for these arguments (a fresh compile on every call) and launches it once
    on ``stream``. Only the first ``pv_n_tile`` columns of ``v`` are handed to
    the kernel, so the output has ``pv_n_tile`` columns.

    Args:
        q: query tensor, passed through unchanged (this file uses
            (m, head_dim, 1) bf16 CUDA tensors).
        k: key tensor, passed through unchanged (this file uses
            (s_k, head_dim, 1) bf16 CUDA tensors).
        v: value tensor sliced along dim 1 (this file uses (s_k, head_dim) bf16).
        head_dim, s_k, use_smem_p: forwarded to ``FmhaKernel``.
        pv_n_tile: number of leading value columns fed to the kernel.
        stream: CUDA stream the kernel is compiled for and launched on.
        lse_tensor, row_sums_tensor: buffers passed to the kernel; presumably
            filled with per-row LSE / row sums (NOTE(review): confirm against
            FmhaKernel).

    Returns:
        ``(c_tile, lse_tensor, row_sums_tensor, kernel)``. ``c_tile`` is a new
        zero-initialised ``(q.shape[0], pv_n_tile, 1)`` bf16 tensor on
        ``q.device`` that is passed to the kernel as its output. The device is
        not synchronized here.

    Raises:
        ValueError: if ``pv_n_tile`` exceeds the number of value columns.
    """
    if pv_n_tile > v.shape[1]:
        # Slicing would silently truncate and leave v narrower than c_tile.
        raise ValueError(f'pv_n_tile={pv_n_tile} exceeds v.shape[1]={v.shape[1]}')

    m = q.shape[0]
    # Contiguous copy of the used value columns, given a trailing unit dim.
    v_kernel = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
    # Allocate on q's device rather than the "current" CUDA device.
    c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device=q.device)

    mQ = _as_cute_tensor(q)
    mK = _as_cute_tensor(k)
    mV = _as_cute_tensor(v_kernel)
    mC = _as_cute_tensor(c_tile)
    mLSE = _as_cute_tensor(lse_tensor)
    mRS = _as_cute_tensor(row_sums_tensor)

    kernel = FmhaKernel(head_dim=head_dim, s_k=s_k, use_smem_p=use_smem_p, normalize=False)
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE, row_sums=mRS)
    compiled(mQ, mK, mV, mC, stream, mLSE, row_sums=mRS)

    return c_tile, lse_tensor, row_sums_tensor, kernel
|
|
|
|
|
|
_PASS_COS = 0.999  # minimum cosine similarity for a kernel case to count as PASS


def _cosine(a, b):
    """Return the cosine similarity of `a` and `b`, each flattened into one vector."""
    return torch.nn.functional.cosine_similarity(
        a.float().flatten().unsqueeze(0), b.float().flatten().unsqueeze(0)
    ).item()


def _stat_buffers(m):
    """Allocate zeroed (m, 1, 1) float32 LSE and row-sum buffers for one launch."""
    lse = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    row_sums = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    return lse, row_sums


def _check_vs_reference(label, q, k, v, hd, s_k, pv_n_tile, stream, scale):
    """Run the kernel on one case, print PASS/FAIL vs the FP32 reference, return the cosine."""
    lse, row_sums = _stat_buffers(q.shape[0])
    out, _, _, _ = run_fmha(q, k, v, hd, s_k, pv_n_tile, False, stream, lse, row_sums)
    torch.cuda.synchronize()
    # run_fmha only feeds v[:, :pv_n_tile] to the kernel; compare against the
    # matching reference columns (shapes would differ when pv_n_tile < hd).
    ref = reference_attention(q[:, :, 0], k[:, :, 0], v[:, :pv_n_tile], scale)
    cos = _cosine(out[:, :, 0], ref)
    status = "PASS" if cos >= _PASS_COS else "FAIL"
    print(f'{label}: cos={cos:.6f} {status}', flush=True)
    return cos


def _python_kv_merge(q, k, v, hd, pv_n_tile, stream, split):
    """Oracle: launch the kernel on keys [:split] and [split:] separately, then
    merge the two partial results in Python by their LSEs. Returns normalized O."""
    parts = []
    for k_part, v_part in ((k[:split], v[:split]), (k[split:], v[split:])):
        lse, row_sums = _stat_buffers(q.shape[0])
        out, lse, row_sums, _ = run_fmha(
            q, k_part, v_part, hd, k_part.shape[0], pv_n_tile, False, stream, lse, row_sums
        )
        parts.append((out, lse, row_sums))
    torch.cuda.synchronize()

    (out0, lse0, rs0), (out1, lse1, rs1) = parts
    # Divide by each part's row sums to get its softmax-normalized output.
    o0_norm = out0[:, :, 0].float() / rs0[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30)
    o1_norm = out1[:, :, 0].float() / rs1[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30)
    l0 = lse0[:, 0, 0].float()
    l1 = lse1[:, 0, 0].float()
    # D5 merge: O = sum(exp(lse_i) * O_i_norm) / sum(exp(lse_i)).
    # Subtracting the per-row max LSE leaves the weight ratios unchanged but
    # keeps exp() from overflowing.
    l_max = torch.maximum(l0, l1)
    w0 = torch.exp(l0 - l_max).unsqueeze(1)
    w1 = torch.exp(l1 - l_max).unsqueeze(1)
    return (w0 * o0_norm + w1 * o1_norm) / (w0 + w1)


def test():
    """Check in-kernel O rescale for s_k = 128, 256 and 384 (1, 2 and 3 KV tiles).

    Each kernel output is compared with an FP32 un-normalized reference by cosine
    similarity (PASS at >= 0.999). For s_k=256, a Python merge of two single-tile
    launches is also reported against the normalized reference as an oracle.

    Raises:
        AssertionError: if any kernel case (not the oracle) falls below the threshold.
    """
    hd = 64
    m = 128
    scale = 1.0 / math.sqrt(hd)
    torch.manual_seed(42)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')

    # ===== Test 1: s_k=128 baseline (no rescale) =====
    s_k1 = 128
    k1 = torch.randn(s_k1, hd, 1, dtype=torch.bfloat16, device='cuda')
    v1 = torch.randn(s_k1, hd, dtype=torch.bfloat16, device='cuda')
    # Constructing (not launching) a kernel is enough to read its PV N-tile size.
    pv_n_tile = FmhaKernel(head_dim=hd, s_k=s_k1, use_smem_p=False, normalize=False).pv_n_tile
    cos1 = _check_vs_reference(
        'Test 1: s_k=128 baseline', q, k1, v1, hd, s_k1, pv_n_tile, stream, scale
    )

    # ===== Test 2: s_k=256 with in-kernel rescale (CUTLASS correction_rescale) =====
    s_k2 = 256
    k2 = torch.randn(s_k2, hd, 1, dtype=torch.bfloat16, device='cuda')
    v2 = torch.randn(s_k2, hd, dtype=torch.bfloat16, device='cuda')
    cos2 = _check_vs_reference(
        'Test 2: s_k=256 in-kernel rescale', q, k2, v2, hd, s_k2, pv_n_tile, stream, scale
    )

    # ===== Test 3: Python KV merge (oracle) =====
    oracle = _python_kv_merge(q, k2, v2, hd, pv_n_tile, stream, split=128)
    # The merged oracle is softmax-normalized, so compare it with the normalized
    # reference rather than the un-normalized one used for the kernel cases.
    scores2 = q[:, :, 0].float() @ k2[:, :, 0].float().T * scale
    ref2_norm = torch.softmax(scores2, dim=-1) @ v2[:, :pv_n_tile].float()
    cos_oracle = _cosine(oracle, ref2_norm)
    print(f'Oracle: Python KV merge: cos={cos_oracle:.6f}', flush=True)

    # ===== Test 4: s_k=384 (3 KV tiles) =====
    s_k3 = 384
    k3 = torch.randn(s_k3, hd, 1, dtype=torch.bfloat16, device='cuda')
    v3 = torch.randn(s_k3, hd, dtype=torch.bfloat16, device='cuda')
    cos3 = _check_vs_reference(
        'Test 4: s_k=384 in-kernel rescale', q, k3, v3, hd, s_k3, pv_n_tile, stream, scale
    )

    # ===== Summary =====
    all_pass = all(c >= _PASS_COS for c in (cos1, cos2, cos3))
    print(f'\nSummary: {"ALL PASS ✅" if all_pass else "SOME FAIL ❌"}', flush=True)
    # Fail loudly so pytest (and the script's exit code) reflect the result.
    assert all_pass, (
        f'cosine below {_PASS_COS}: s_k=128 {cos1:.6f}, '
        f's_k=256 {cos2:.6f}, s_k=384 {cos3:.6f}'
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the test cases when this file is executed directly.
    test()
|