Files
nvfp4-megamoe-kernel/tests/debug_stages.py
biondizzle 97656a5cd1 Stage B: two MMAs + identity softmax — crash fixed, softmax output still wrong
Key fixes:
- PipelineUmmaAsync consumer group: sized as a thread count, 32*4 = 128 threads (not as a count of 4 warps)
- TMEM offsets computed from find_tmem_tensor_col_offset (not hardcoded)
- P fragment from p_tmem_s.outer + make_fragment_A (matching fmha.py)
- V SMEM aliasing via recast_ptr

Status:
- Stage A: cosine 0.999999 (pass)
- Stage B: runs without crash, but identity-softmax output is still wrong (cosine -0.02, fail)
- Diagnostics: TMEM layout inspection, bisection results
2026-05-20 20:26:25 +00:00

60 lines
2.8 KiB
Python

"""
Debug: test each stage independently.
1. Run Stage A (Q @ K^T only) — should give cosine 0.999
2. Run Stage B minimal (two MMAs, no softmax) — should give NaN or garbage
3. Run Stage B pipeline-only (pipeline but no ld/st) — should give NaN or garbage
4. Run Stage B full (identity softmax) — should give correct (Q@K^T)@V
"""
import torch
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
# Fixed seed so the random inputs, and therefore every printed number, are
# reproducible from run to run.
torch.manual_seed(42)
m, n, k = 128, 128, 128

# Kernel inputs carry a trailing unit dimension: q is (m, k, 1), kv is (n, k, 1).
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')

# fp32 references computed from the same bf16 data. kv is used both as K and
# as V in the reference: ref_qktv = (Q @ K^T) @ K.
qf = q[:, :, 0].float()
kvf = kv[:, :, 0].float()
ref_qkt = qf @ kvf.T        # (m, k) @ (k, n) -> (m, n)
ref_qktv = ref_qkt @ kvf    # (m, n) @ (n, k) -> (m, k)

print(f"Q shape: {q.shape}, KV shape: {kv.shape}")
print(f"Q@K^T shape: {ref_qkt.shape}, (Q@K^T)@V shape: {ref_qktv.shape}")
print(f"Q@K^T range: [{ref_qkt.min():.2f}, {ref_qkt.max():.2f}]")
print(f"(Q@K^T)@V range: [{ref_qktv.min():.2f}, {ref_qktv.max():.2f}]")
# Test Stage A first: Q @ K^T only. The result is compared against the fp32
# reference ref_qkt and should reach a cosine of ~1.0.
from test_stage_a_v2 import StageAQKTKernel

c_a = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
# Wrap the torch tensors as CuTe tensors with a dynamic layout; the leading
# dimension is taken from each tensor's actual strides.
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c_a).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_a))
# Launch on torch's current CUDA stream; Stage B below reuses this stream.
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel_a = StageAQKTKernel(mma_tiler_mn=(128, 128))
compiled_a = cute.compile(kernel_a, mQ, mK, mC, stream)
compiled_a(mQ, mK, mC, stream)
torch.cuda.synchronize()
out_a = c_a[:, :, 0].float()
cos_a = torch.nn.functional.cosine_similarity(
    out_a.flatten(), ref_qkt.flatten(), dim=0
).item()
# Previously both branches of this marker were '' (the symbols had been
# stripped), so pass and fail printed identically.
print(f"\nStage A (Q@K^T): cosine = {cos_a:.6f} {'PASS' if cos_a > 0.99 else 'FAIL'}")
# Test Stage B v7: two MMAs with an identity softmax. The kernel is passed only
# mQ and mK, and the reference uses K as V (ref_qktv = ref_qkt @ kvf). The
# expected output is therefore (Q @ K^T) @ K, which has shape (m, k).
from test_stage_b_v7 import StageBIdentitySoftmax

# Size the output like ref_qktv, i.e. (m, k). It was previously (m, n), which
# only matched by coincidence because n == k here; with n != k the flattened
# cosine comparison below would fail on a size mismatch.
c_b = torch.zeros(m, k, 1, dtype=torch.bfloat16, device='cuda')
mC2 = ct.from_dlpack(c_b).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_b))
kernel_b = StageBIdentitySoftmax(mma_tiler_mn=(128, 128))
compiled_b = cute.compile(kernel_b, mQ, mK, mC2, stream)
compiled_b(mQ, mK, mC2, stream)
torch.cuda.synchronize()
out_b = c_b[:, :, 0].float()
cos_b = torch.nn.functional.cosine_similarity(
    out_b.flatten(), ref_qktv.flatten(), dim=0
).item()
has_nan = torch.isnan(out_b).any().item()
# Previously both branches of this marker were '' (the symbols had been
# stripped), so pass and fail printed identically.
print(f"Stage B (identity softmax): cosine = {cos_b:.6f}, has_nan = {has_nan} {'PASS' if cos_b > 0.99 else 'FAIL'}")
# Check: is the output close to Q@K^T (not Q@K^T@V)? This can only be compared
# when the output and Q@K^T have the same shape (n == k).
if out_b.shape == ref_qkt.shape:
    cos_b_qkt = torch.nn.functional.cosine_similarity(
        out_b.flatten(), ref_qkt.flatten(), dim=0
    ).item()
    print(f" vs Q@K^T: cosine = {cos_b_qkt:.6f} (should be ~0 if it's Q@K^T@V)")
else:
    print(f" vs Q@K^T: skipped (output shape {tuple(out_b.shape)} != Q@K^T shape {tuple(ref_qkt.shape)})")
# nan_to_num replaces NaN/inf, so the printed range stays finite; compute it once.
out_b_finite = out_b.nan_to_num()
print(f" Output range: [{out_b_finite.min():.2f}, {out_b_finite.max():.2f}]")