From 42c5793add89c837d76782b4eca3f64d14ebd3ea Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 26 May 2026 20:45:58 +0000 Subject: [PATCH] D1.5: Add isolated round-trip test comparing s_k=128 vs s_k=256 with NOOP rescale --- dsv4/kernels/attention/fmha.py | 22 ++++-- tests/unit/test_d15_roundtrip_iso.py | 108 +++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 7 deletions(-) create mode 100644 tests/unit/test_d15_roundtrip_iso.py diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 008c31f7..761fdaeb 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -418,18 +418,26 @@ class FmhaKernel: # sub-tiles, and building BOTH copies from the SAME tensor, ensures the # column mappings agree on round-trip. # ============================================================ - corr_tile_size = 32 # Must be power of 2, divides head_dim. Try 32 instead of 16. - tOtO_i_layout = cute.composition( - tOtO0.layout, cute.make_layout((128, corr_tile_size)) - ) + corr_tile_size = 16 # Must be power of 2, divides head_dim + # Try both composition and raw layout + use_comp = True + if const_expr(use_comp): + tOtO_i_layout = cute.composition( + tOtO0.layout, cute.make_layout((128, corr_tile_size)) + ) + else: + tOtO_i_layout = tOtO0.layout tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout) # Coordinate tensor for O (needed for partition_D of load) cO = cute.make_identity_tensor((128, self.head_dim)) tOcO = pv_thr.partition_C(cO) - tOcO_i_layout = cute.composition( - tOcO.layout, cute.make_layout((128, corr_tile_size)) - ) + if const_expr(use_comp): + tOcO_i_layout = cute.composition( + tOcO.layout, cute.make_layout((128, corr_tile_size)) + ) + else: + tOcO_i_layout = tOcO.layout tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) tmem_load_o_atom = cute.make_copy_atom( diff --git a/tests/unit/test_d15_roundtrip_iso.py b/tests/unit/test_d15_roundtrip_iso.py new file mode 100644 index 00000000..c919a4b3 --- /dev/null +++ b/tests/unit/test_d15_roundtrip_iso.py @@ -0,0 +1,108 @@ +""" +D1.5 Minimal TMEM round-trip test within FMHA. + +Tests: run FMHA with s_k=128, then add a NOOP correction_rescale +(multiply by 1.0) BETWEEN the softmax and the epilogue. +If the output is bitwise identical to without rescale, the round-trip works. +If it differs, the round-trip corrupts data. + +This test directly compares: same kernel, same data, WITH and WITHOUT +the O rescale step (forced to NOOP). +""" +import torch, math +import cutlass.cute as cute +import cutlass.torch as ct +import cuda.bindings.driver as cuda +from dsv4.kernels.attention.fmha import FmhaKernel + + +def test(): + hd = 64; m = 128; s_k = 128; scale = 1.0 / math.sqrt(hd) + torch.manual_seed(42) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') + k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda') + v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda') + + # Run WITHOUT rescale (s_k=128, n_kv_tiles=1, rescale code is dead) + kernel1 = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=False, normalize=False) + pv_n_tile = kernel1.pv_n_tile + c1 = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') + lse1 = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') + rs1 = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') + + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + v_t = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1) + mV = ct.from_dlpack(v_t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_t)) + mC1 = ct.from_dlpack(c1).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c1)) + mLSE1 = ct.from_dlpack(lse1).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse1)) + mRS1 = ct.from_dlpack(rs1).mark_layout_dynamic(leading_dim=ct.get_leading_dim(rs1)) + + print('Compiling s_k=128 (no rescale)...', flush=True) + comp1 = cute.compile(kernel1, mQ, mK, mV, mC1, stream, mLSE1, row_sums=mRS1) + comp1(mQ, mK, mV, mC1, stream, mLSE1, row_sums=mRS1) + torch.cuda.synchronize() + + # Run WITH NOOP rescale (s_k=256, but with debug_noop_rescale=True, + # using same K/V repeated to make the second segment identical) + # This triggers the O rescale code path (n_kv_tiles=2) but with factor=1.0 + k2 = torch.cat([k, k], dim=0) + v2 = torch.cat([v, v], dim=0) + + kernel2 = FmhaKernel(head_dim=hd, s_k=256, use_smem_p=False, normalize=False, debug_noop_rescale=True) + c2 = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') + lse2 = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') + rs2 = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') + + mK2 = ct.from_dlpack(k2).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k2)) + v2_t = v2[:, 0:pv_n_tile].contiguous().unsqueeze(-1) + mV2 = ct.from_dlpack(v2_t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v2_t)) + mC2 = ct.from_dlpack(c2).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c2)) + mLSE2 = ct.from_dlpack(lse2).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse2)) + mRS2 = ct.from_dlpack(rs2).mark_layout_dynamic(leading_dim=ct.get_leading_dim(rs2)) + + print('Compiling s_k=256 (NOOP rescale)...', flush=True) + comp2 = cute.compile(kernel2, mQ, mK2, mV2, mC2, stream, mLSE2, row_sums=mRS2) + comp2(mQ, mK2, mV2, mC2, stream, mLSE2, row_sums=mRS2) + torch.cuda.synchronize() + + # FP32 reference + qf = q[:, :, 0].float() + kf = k[:, :, 0].float() + attn = qf @ kf.T * scale + attn_max = attn.max(dim=-1, keepdim=True)[0] + attn_exp = torch.exp(attn - attn_max) + ref_norm = (attn_exp @ v.float()) / attn_exp.sum(dim=-1, keepdim=True) + + # Normalize both outputs + out1 = c1[:, :, 0].float() / rs1[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30) + out2 = c2[:, :, 0].float() / rs2[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30) + + cos1 = torch.nn.functional.cosine_similarity(out1.flatten().unsqueeze(0), ref_norm.flatten().unsqueeze(0)).item() + cos2 = torch.nn.functional.cosine_similarity(out2.flatten().unsqueeze(0), ref_norm.flatten().unsqueeze(0)).item() + cos_12 = torch.nn.functional.cosine_similarity(out1.flatten().unsqueeze(0), out2.flatten().unsqueeze(0)).item() + + # Also check unnormalized + unnorm1 = c1[:, :, 0].float() + unnorm2 = c2[:, :, 0].float() + # With identical segments, s_k=256 unnorm should be 2x s_k=128 unnorm + # (row_sum also doubles, so normalized result is the same) + ratio = unnorm2 / unnorm1.clamp(min=1e-10) + print(f'Unnorm ratio (should be ~2.0): mean={ratio.mean().item():.4f} std={ratio.std().item():.4f}') + + print(f's_k=128 (no rescale): cos={cos1:.6f}') + print(f's_k=256 (NOOP rescale): cos={cos2:.6f}') + print(f's_k=128 vs s_k=256: cos={cos_12:.6f}') + + # Bitwise comparison on unnormalized O + # s_k=256 with identical segments: O_unnorm = 2 * O_unnorm_128 + expected_unnorm2 = 2.0 * unnorm1 + bit_cos = torch.nn.functional.cosine_similarity(unnorm2.flatten().unsqueeze(0), expected_unnorm2.flatten().unsqueeze(0)).item() + max_diff = (unnorm2 - expected_unnorm2).abs().max().item() + print(f'Unnorm: expected 2x, got cos={bit_cos:.6f} max_diff={max_diff:.4f}') + + +if __name__ == '__main__': + test()