"""
Debug: test each stage independently.

Debugging plan:

1. Run Stage A (Q @ K^T only) — should give cosine ≈ 0.999
2. Run Stage B minimal (two MMAs, no softmax) — should give NaN or garbage
3. Run Stage B pipeline-only (pipeline but no ld/st) — should give NaN or garbage
4. Run Stage B full (identity softmax) — should reproduce (Q@K^T)@V

This script runs stages 1 and 4 only; stages 2 and 3 are not exercised here.
The same random tensor ``kv`` is used as both K and V.
"""

import torch
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda

# Seed the RNG so every debugging run sees identical inputs.
torch.manual_seed(42)

# Problem size. Q is (m, k, 1) and K is (n, k, 1); the same `kv` tensor is
# also used as V in the (Q @ K^T) @ V reference, so V is (n, k) as well.
m, n, k = 128, 128, 128

# bf16 inputs on the GPU. The trailing size-1 dim is kept as-is for the
# kernels (its meaning is not shown in this script -- TODO confirm).
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')

# fp32 references computed by PyTorch, with the trailing size-1 dim dropped.
qf = q[..., 0].float()
kvf = kv[..., 0].float()
ref_qkt = qf @ kvf.T        # Q @ K^T        -> (m, n)
ref_qktv = ref_qkt @ kvf    # (Q @ K^T) @ V  -> (m, k); V is the same kv tensor

# Echo shapes and value ranges so kernel output can be sanity-checked by eye.
print(f"Q shape: {q.shape}, KV shape: {kv.shape}")
print(f"Q@K^T shape: {ref_qkt.shape}, (Q@K^T)@V shape: {ref_qktv.shape}")
print(f"Q@K^T range: [{ref_qkt.min():.2f}, {ref_qkt.max():.2f}]")
print(f"(Q@K^T)@V range: [{ref_qktv.min():.2f}, {ref_qktv.max():.2f}]")

# Test Stage A first: Q @ K^T only, checked against the fp32 reference ref_qkt.
from test_stage_a_v2 import StageAQKTKernel

# Zero-initialised bf16 output buffer for Q @ K^T, shape (m, n, 1).
c_a = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')

# Wrap Q, K and C as CuTe tensors, marking each layout dynamic along that
# tensor's own leading dimension (evaluated in the order q, kv, c_a).
mQ, mK, mC = (
    ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))
    for t in (q, kv, c_a)
)

# Launch on PyTorch's current CUDA stream.
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

kernel_a = StageAQKTKernel(mma_tiler_mn=(128, 128))
compiled_a = cute.compile(kernel_a, mQ, mK, mC, stream)
compiled_a(mQ, mK, mC, stream)
torch.cuda.synchronize()  # wait for the kernel before reading c_a on the host side

# Compare the whole output matrix with the reference as one long row vector.
out_a = c_a[..., 0].float()
cos_a = torch.nn.functional.cosine_similarity(
    out_a.reshape(1, -1), ref_qkt.reshape(1, -1)
).item()
verdict_a = '✅' if cos_a > 0.99 else '❌'
print(f"\nStage A (Q@K^T): cosine = {cos_a:.6f} {verdict_a}")

# Test Stage B v7 (identity softmax): Q @ K^T followed by a second MMA with V.
# The kernel receives no separate V tensor (only mQ and mK), so it presumably
# reuses K as V -- which is what the reference ref_qktv = (Q @ K^T) @ kv assumes.
from test_stage_b_v7 import StageBIdentitySoftmax

# (Q @ K^T) @ V has shape (m, k): rows follow Q, columns follow V's last dim.
# Size the buffer from k, not n -- they are equal here, but only (m, k)
# matches ref_qktv in general (a mismatch would break the cosine below).
c_b = torch.zeros(m, k, 1, dtype=torch.bfloat16, device='cuda')
mC2 = ct.from_dlpack(c_b).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_b))

kernel_b = StageBIdentitySoftmax(mma_tiler_mn=(128, 128))
compiled_b = cute.compile(kernel_b, mQ, mK, mC2, stream)
compiled_b(mQ, mK, mC2, stream)
torch.cuda.synchronize()  # wait for the kernel before reading c_b

out_b = c_b[:, :, 0].float()
# Any NaN in out_b makes the cosine NaN, and NaN > 0.99 is False, so a NaN
# output is always reported as a failure; has_nan says why.
cos_b = torch.nn.functional.cosine_similarity(out_b.flatten().unsqueeze(0), ref_qktv.flatten().unsqueeze(0)).item()
has_nan = torch.isnan(out_b).any().item()
print(f"Stage B (identity softmax): cosine = {cos_b:.6f}, has_nan = {has_nan} {'✅' if cos_b > 0.99 else '❌'}")

# Check: is the output close to Q@K^T (not Q@K^T@V)?
# A cosine near 1 would suggest the output is just Q@K^T (the second MMA had
# no effect); for a correct Q@K^T@V result it should be near 0.
# The element-wise comparison is only defined when both matrices have the same
# shape, i.e. when n == k; otherwise cosine_similarity would raise, so skip it.
if out_b.shape == ref_qkt.shape:
    cos_b_qkt = torch.nn.functional.cosine_similarity(out_b.flatten().unsqueeze(0), ref_qkt.flatten().unsqueeze(0)).item()
    print(f" vs Q@K^T: cosine = {cos_b_qkt:.6f} (should be ~0 if it's Q@K^T@V)")
else:
    print(f" vs Q@K^T: skipped (output shape {tuple(out_b.shape)} != Q@K^T shape {tuple(ref_qkt.shape)})")

# nan_to_num maps NaN to 0 (and +/-inf to finite extremes) so the range prints.
print(f" Output range: [{out_b.nan_to_num().min():.2f}, {out_b.nan_to_num().max():.2f}]")