D1.5 debug: add NOOP rescale test (acc_scale=1.0) to isolate TMEM round-trip corruption

This commit is contained in:
2026-05-26 20:28:55 +00:00
parent c3648e4ebf
commit 3be708d923
2 changed files with 84 additions and 2 deletions

View File

@@ -21,7 +21,7 @@ import math
class FmhaKernel:
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False, n_comp=None, apply_sink_bias=False):
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False, n_comp=None, apply_sink_bias=False, debug_noop_rescale=False):
# D5c: n_comp = compressed KV length. Sink bias (attn_sink) applies to
# positions >= n_comp. D3/D4 masks also only apply to SWA region.
# When n_comp is None or 0, no offset (backward compatible).
@@ -58,6 +58,8 @@ class FmhaKernel:
self.kv_stage = 1 if head_dim > 128 else 2 # Reduce SMEM at large hd
self.q_stage = 1
self.num_c_stage = 1 if head_dim > 256 else 2 # Reduce SMEM at hd=512
self.debug_noop_rescale = debug_noop_rescale # D1.5 debug: force acc_scale=1.0 in O rescale
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim)
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
@@ -556,6 +558,9 @@ class FmhaKernel:
pv_done_bar.arrive_and_wait() # Wait for PV[kt-1]
# Rescale O: load, multiply by acc_scale, store back to TMEM.
# CUTLASS pattern: both copies use same tOtO_i (composition-tiled).
rescale_factor = acc_scale
if const_expr(self.debug_noop_rescale):
rescale_factor = Float32(1.0)
n_slices = self.head_dim // corr_tile_size
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, n_slices), self.qk_acc_dtype
@@ -575,7 +580,7 @@ class FmhaKernel:
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
cute.arch.fence_view_async_tmem_load()
for k in cutlass.range(cute.size(tTMrO_i), vectorize=True):
tTMrO_i[k] = tTMrO_i[k] * acc_scale
tTMrO_i[k] = tTMrO_i[k] * rescale_factor
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
cute.arch.fence_view_async_tmem_store()

View File

@@ -0,0 +1,77 @@
"""
D1.5 Debug: NO-OP TMEM round-trip test.
Tests s_k=256 with rescale_factor=1.0 (NO-OP).
If the round-trip itself is broken, even NO-OP will corrupt O.
If it produces the same (wrong) result as without rescale, the round-trip works.
"""
import torch, math
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def reference_attention(q, k, v, scale):
qf = q.float(); kf = k.float()
attn = qf @ kf.T * scale
attn_max = attn.max(dim=-1, keepdim=True)[0]
attn_exp = torch.exp(attn - attn_max)
return attn_exp @ v.float()
def run_test(s_k, debug_noop=False, label=""):
hd = 64; m = 128; scale = 1.0 / math.sqrt(hd)
torch.manual_seed(42)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda')
kernel = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=False, normalize=False, debug_noop_rescale=debug_noop)
pv_n_tile = kernel.pv_n_tile
c = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
lse = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
rs = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
v_t = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_t))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
mLSE = ct.from_dlpack(lse).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse))
mRS = ct.from_dlpack(rs).mark_layout_dynamic(leading_dim=ct.get_leading_dim(rs))
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE, row_sums=mRS)
compiled(mQ, mK, mV, mC, stream, mLSE, row_sums=mRS)
torch.cuda.synchronize()
ref = reference_attention(q[:, :, 0], k[:, :, 0], v, scale)
out = c[:, :, 0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print(f'{label}: cos={cos:.6f} {"PASS" if cos >= 0.999 else "FAIL"}', flush=True)
return out
def test():
# Test 1: s_k=128 baseline
run_test(128, label="s_k=128 baseline")
# Test 2: s_k=256 WITH rescale (should be correct if TMEM round-trip works)
run_test(256, label="s_k=256 with rescale")
# Test 3: s_k=256 NOOP rescale (acc_scale=1.0)
# This should produce the same result as s_k=256 WITHOUT any rescale
# (which is: O = P[0]*V[0] + P[1]*V[1] with no O rescale — mathematically wrong
# but should be stable if TMEM round-trip doesn't corrupt)
out_noop = run_test(256, debug_noop=True, label="s_k=256 NOOP rescale")
# Test 4: s_k=256 WITHOUT any rescale (old code path)
# Compare with NOOP to see if TMEM round-trip itself corrupts
# We can't easily disable the rescale in the current code,
# but NOOP rescale with factor=1.0 is equivalent to a successful round-trip
# followed by multiply-by-1. If the output matches a "no-rescale" baseline,
# the TMEM round-trip is working correctly.
if __name__ == '__main__':
test()