D1.5 debug: add NOOP rescale test (acc_scale=1.0) to isolate TMEM round-trip corruption
This commit is contained in:
@@ -21,7 +21,7 @@ import math
|
||||
|
||||
|
||||
class FmhaKernel:
|
||||
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False, n_comp=None, apply_sink_bias=False):
|
||||
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False, n_comp=None, apply_sink_bias=False, debug_noop_rescale=False):
|
||||
# D5c: n_comp = compressed KV length. Sink bias (attn_sink) applies to
|
||||
# positions >= n_comp. D3/D4 masks also only apply to SWA region.
|
||||
# When n_comp is None or 0, no offset (backward compatible).
|
||||
@@ -58,6 +58,8 @@ class FmhaKernel:
|
||||
self.kv_stage = 1 if head_dim > 128 else 2 # Reduce SMEM at large hd
|
||||
self.q_stage = 1
|
||||
self.num_c_stage = 1 if head_dim > 256 else 2 # Reduce SMEM at hd=512
|
||||
self.debug_noop_rescale = debug_noop_rescale # D1.5 debug: force acc_scale=1.0 in O rescale
|
||||
|
||||
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim)
|
||||
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
|
||||
|
||||
@@ -556,6 +558,9 @@ class FmhaKernel:
|
||||
pv_done_bar.arrive_and_wait() # Wait for PV[kt-1]
|
||||
# Rescale O: load, multiply by acc_scale, store back to TMEM.
|
||||
# CUTLASS pattern: both copies use same tOtO_i (composition-tiled).
|
||||
rescale_factor = acc_scale
|
||||
if const_expr(self.debug_noop_rescale):
|
||||
rescale_factor = Float32(1.0)
|
||||
n_slices = self.head_dim // corr_tile_size
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, n_slices), self.qk_acc_dtype
|
||||
@@ -575,7 +580,7 @@ class FmhaKernel:
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
for k in cutlass.range(cute.size(tTMrO_i), vectorize=True):
|
||||
tTMrO_i[k] = tTMrO_i[k] * acc_scale
|
||||
tTMrO_i[k] = tTMrO_i[k] * rescale_factor
|
||||
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
|
||||
77
tests/unit/test_d15_noop_rescale.py
Normal file
77
tests/unit/test_d15_noop_rescale.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
D1.5 Debug: NO-OP TMEM round-trip test.
|
||||
Tests s_k=256 with rescale_factor=1.0 (NO-OP).
|
||||
If the round-trip itself is broken, even NO-OP will corrupt O.
|
||||
If it produces the same (wrong) result as without rescale, the round-trip works.
|
||||
"""
|
||||
import torch, math
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as ct
|
||||
import cuda.bindings.driver as cuda
|
||||
from dsv4.kernels.attention.fmha import FmhaKernel
|
||||
|
||||
|
||||
def reference_attention(q, k, v, scale):
|
||||
qf = q.float(); kf = k.float()
|
||||
attn = qf @ kf.T * scale
|
||||
attn_max = attn.max(dim=-1, keepdim=True)[0]
|
||||
attn_exp = torch.exp(attn - attn_max)
|
||||
return attn_exp @ v.float()
|
||||
|
||||
|
||||
def run_test(s_k, debug_noop=False, label=""):
|
||||
hd = 64; m = 128; scale = 1.0 / math.sqrt(hd)
|
||||
torch.manual_seed(42)
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
kernel = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=False, normalize=False, debug_noop_rescale=debug_noop)
|
||||
pv_n_tile = kernel.pv_n_tile
|
||||
c = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
|
||||
lse = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
|
||||
rs = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
|
||||
|
||||
v_t = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
mV = ct.from_dlpack(v_t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_t))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
mLSE = ct.from_dlpack(lse).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse))
|
||||
mRS = ct.from_dlpack(rs).mark_layout_dynamic(leading_dim=ct.get_leading_dim(rs))
|
||||
|
||||
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE, row_sums=mRS)
|
||||
compiled(mQ, mK, mV, mC, stream, mLSE, row_sums=mRS)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
ref = reference_attention(q[:, :, 0], k[:, :, 0], v, scale)
|
||||
out = c[:, :, 0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
print(f'{label}: cos={cos:.6f} {"PASS" if cos >= 0.999 else "FAIL"}', flush=True)
|
||||
return out
|
||||
|
||||
|
||||
def test():
|
||||
# Test 1: s_k=128 baseline
|
||||
run_test(128, label="s_k=128 baseline")
|
||||
|
||||
# Test 2: s_k=256 WITH rescale (should be correct if TMEM round-trip works)
|
||||
run_test(256, label="s_k=256 with rescale")
|
||||
|
||||
# Test 3: s_k=256 NOOP rescale (acc_scale=1.0)
|
||||
# This should produce the same result as s_k=256 WITHOUT any rescale
|
||||
# (which is: O = P[0]*V[0] + P[1]*V[1] with no O rescale — mathematically wrong
|
||||
# but should be stable if TMEM round-trip doesn't corrupt)
|
||||
out_noop = run_test(256, debug_noop=True, label="s_k=256 NOOP rescale")
|
||||
|
||||
# Test 4: s_k=256 WITHOUT any rescale (old code path)
|
||||
# Compare with NOOP to see if TMEM round-trip itself corrupts
|
||||
# We can't easily disable the rescale in the current code,
|
||||
# but NOOP rescale with factor=1.0 is equivalent to a successful round-trip
|
||||
# followed by multiply-by-1. If the output matches a "no-rescale" baseline,
|
||||
# the TMEM round-trip is working correctly.
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
Reference in New Issue
Block a user