"""Debug: test attention-kernel stages independently against fp32 torch references.

Checks run by this script:

1. Stage A (``StageAQKTKernel`` from ``test_stage_a_v2``): computes Q @ K^T.
   Expected cosine similarity vs. the reference is ~0.999; it is reported as
   a pass when the cosine exceeds 0.99.
2. Stage B (``StageBIdentitySoftmax`` from ``test_stage_b_v7``): identity
   softmax, so the output should equal (Q @ K^T) @ V. Reported as a pass when
   the cosine vs. that reference exceeds 0.99. As a diagnostic, the output is
   also compared against Q @ K^T alone.

Other Stage B variants (minimal two-MMA without softmax, pipeline-only without
ld/st) are not exercised by this script.

K and V are the same tensor (``kv``). Requires a CUDA device and the sibling
modules named above on the import path.
"""
import torch
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda


def _cosine(a, b):
    """Return the cosine similarity of *a* and *b* (each flattened) as a Python float."""
    return torch.nn.functional.cosine_similarity(
        a.flatten().unsqueeze(0), b.flatten().unsqueeze(0)
    ).item()


def _to_cute(t):
    """Wrap torch tensor *t* via ``ct.from_dlpack`` and mark its layout dynamic.

    The leading dim passed to ``mark_layout_dynamic`` is whatever
    ``ct.get_leading_dim`` reports for *t*.
    """
    return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))


torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')

# fp32 references computed from the bf16 inputs; kv plays both K and V.
qf = q[:, :, 0].float()
kvf = kv[:, :, 0].float()
ref_qkt = qf @ kvf.T        # (m, n)
ref_qktv = ref_qkt @ kvf    # (m, n) @ (n, k) -> (m, k)

print(f"Q shape: {q.shape}, KV shape: {kv.shape}")
print(f"Q@K^T shape: {ref_qkt.shape}, (Q@K^T)@V shape: {ref_qktv.shape}")
print(f"Q@K^T range: [{ref_qkt.min():.2f}, {ref_qkt.max():.2f}]")
print(f"(Q@K^T)@V range: [{ref_qktv.min():.2f}, {ref_qktv.max():.2f}]")

# ---- Stage A: Q @ K^T ----
# Imported here (not at the top) so the reference summary above prints even if
# the kernel module fails to import.
from test_stage_a_v2 import StageAQKTKernel

c_a = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
mQ = _to_cute(q)
mK = _to_cute(kv)
mC = _to_cute(c_a)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

kernel_a = StageAQKTKernel(mma_tiler_mn=(128, 128))
compiled_a = cute.compile(kernel_a, mQ, mK, mC, stream)
compiled_a(mQ, mK, mC, stream)
torch.cuda.synchronize()  # wait for the device before reading c_a on the host

out_a = c_a[:, :, 0].float()
cos_a = _cosine(out_a, ref_qkt)
print(f"\nStage A (Q@K^T): cosine = {cos_a:.6f} {'✅' if cos_a > 0.99 else '❌'}")

# ---- Stage B: identity softmax, expected (Q @ K^T) @ V ----
from test_stage_b_v7 import StageBIdentitySoftmax

# (Q @ K^T) @ V is (m, k), so size the output by k rather than n; this keeps
# the buffer comparable with ref_qktv even if n != k.
c_b = torch.zeros(m, k, 1, dtype=torch.bfloat16, device='cuda')
mC2 = _to_cute(c_b)

kernel_b = StageBIdentitySoftmax(mma_tiler_mn=(128, 128))
compiled_b = cute.compile(kernel_b, mQ, mK, mC2, stream)
compiled_b(mQ, mK, mC2, stream)
torch.cuda.synchronize()

out_b = c_b[:, :, 0].float()
# A NaN anywhere makes the cosine NaN, and NaN > 0.99 is False, so it reports ❌.
cos_b = _cosine(out_b, ref_qktv)
has_nan = torch.isnan(out_b).any().item()
print(f"Stage B (identity softmax): cosine = {cos_b:.6f}, has_nan = {has_nan} {'✅' if cos_b > 0.99 else '❌'}")

# Diagnostic: is the output close to Q@K^T (not Q@K^T@V)?
# Only meaningful when the shapes line up, i.e. n == k.
if out_b.shape == ref_qkt.shape:
    cos_b_qkt = _cosine(out_b, ref_qkt)
    print(f"  vs Q@K^T: cosine = {cos_b_qkt:.6f} (should be ~0 if it's Q@K^T@V)")
print(f"  Output range: [{out_b.nan_to_num().min():.2f}, {out_b.nan_to_num().max():.2f}]")