diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 01b55d17..5cf57188 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -21,7 +21,7 @@ import math class FmhaKernel: - def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False, n_comp=None, apply_sink_bias=False): + def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False, n_comp=None, apply_sink_bias=False, debug_noop_rescale=False): # D5c: n_comp = compressed KV length. Sink bias (attn_sink) applies to # positions >= n_comp. D3/D4 masks also only apply to SWA region. # When n_comp is None or 0, no offset (backward compatible). @@ -58,6 +58,8 @@ class FmhaKernel: self.kv_stage = 1 if head_dim > 128 else 2 # Reduce SMEM at large hd self.q_stage = 1 self.num_c_stage = 1 if head_dim > 256 else 2 # Reduce SMEM at hd=512 + self.debug_noop_rescale = debug_noop_rescale # D1.5 debug: force acc_scale=1.0 in O rescale + self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim) self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e) @@ -556,6 +558,9 @@ class FmhaKernel: pv_done_bar.arrive_and_wait() # Wait for PV[kt-1] # Rescale O: load, multiply by acc_scale, store back to TMEM. # CUTLASS pattern: both copies use same tOtO_i (composition-tiled). 
+ rescale_factor = acc_scale + if const_expr(self.debug_noop_rescale): + rescale_factor = Float32(1.0) n_slices = self.head_dim // corr_tile_size tTMrO = cute.make_rmem_tensor( (tTMEM_LOADcO.shape, n_slices), self.qk_acc_dtype @@ -575,7 +580,7 @@ class FmhaKernel: cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) cute.arch.fence_view_async_tmem_load() for k in cutlass.range(cute.size(tTMrO_i), vectorize=True): - tTMrO_i[k] = tTMrO_i[k] * acc_scale + tTMrO_i[k] = tTMrO_i[k] * rescale_factor cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) cute.arch.fence_view_async_tmem_store() diff --git a/tests/unit/test_d15_noop_rescale.py b/tests/unit/test_d15_noop_rescale.py new file mode 100644 index 00000000..3dcc44fa --- /dev/null +++ b/tests/unit/test_d15_noop_rescale.py @@ -0,0 +1,77 @@ +""" +D1.5 Debug: NO-OP TMEM round-trip test. +Tests s_k=256 with rescale_factor=1.0 (NO-OP). +If the round-trip itself is broken, even NO-OP will corrupt O. +If it produces the same (wrong) result as without rescale, the round-trip works. 
+"""
+import math
+
+import torch
+import cutlass.cute as cute
+import cutlass.torch as ct
+import cuda.bindings.driver as cuda
+
+from dsv4.kernels.attention.fmha import FmhaKernel
+
+
+def reference_attention(q, k, v, scale):
+    """Return exp(q @ k.T * scale - row_max) @ v in float32 (no row-sum division)."""
+    attn = q.float() @ k.float().T * scale
+    attn_exp = torch.exp(attn - attn.max(dim=-1, keepdim=True).values)
+    return attn_exp @ v.float()
+
+
+def run_test(s_k, debug_noop=False, label=""):
+    """Run FmhaKernel once and print its cosine similarity to the reference.
+
+    Args:
+        s_k: number of key/value rows.
+        debug_noop: forwarded to FmhaKernel as ``debug_noop_rescale``.
+        label: prefix of the printed result line.
+    Returns:
+        The kernel output as a float32 tensor of shape (128, kernel.pv_n_tile).
+    """
+    hd, m = 64, 128
+    scale = 1.0 / math.sqrt(hd)
+    torch.manual_seed(42)
+    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
+    k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
+    v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda')
+    kernel = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=False, normalize=False, debug_noop_rescale=debug_noop)
+    pv_n_tile = kernel.pv_n_tile
+    c = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
+    lse = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
+    rs = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
+    v_t = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
+
+    mQ, mK, mV, mC, mLSE, mRS = (ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))
+                                 for t in (q, k, v_t, c, lse, rs))
+    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE, row_sums=mRS)
+    compiled(mQ, mK, mV, mC, stream, mLSE, row_sums=mRS)
+    torch.cuda.synchronize()
+
+    # Build the reference from the same V slice the kernel received so both sides have equal width.
+    ref = reference_attention(q[:, :, 0], k[:, :, 0], v_t[:, :, 0], scale)
+    out = c[:, :, 0].float()
+    cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
+    print(f'{label}: cos={cos:.6f} {"PASS" if cos >= 0.999 else "FAIL"}', flush=True)
+    return out
+
+
+def test():
+    """Run the three D1.5 configurations, then fail if any output is non-finite."""
+    # Test 1: s_k=128 baseline.
+    out_base = run_test(128, label="s_k=128 baseline")
+    # Test 2: s_k=256 WITH rescale (should be correct if the TMEM round-trip works).
+    out_rescale = run_test(256, label="s_k=256 with rescale")
+    # Test 3: s_k=256 NOOP rescale (acc_scale=1.0): O = P[0]*V[0] + P[1]*V[1] with no
+    # O rescale -- mathematically wrong, but stable if the round-trip does not corrupt O.
+    out_noop = run_test(256, debug_noop=True, label="s_k=256 NOOP rescale")
+    # Test 4 (a true no-rescale baseline) is not implemented; only check outputs are finite.
+    for name, out in (("baseline", out_base), ("rescale", out_rescale), ("noop", out_noop)):
+        assert torch.isfinite(out).all(), f"{name}: non-finite values in kernel output"
+
+
+if __name__ == '__main__':
+    test()