Files
nvfp4-megamoe-kernel/tests/archive/test_error_pattern.py
biondizzle 3fb3c925af Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00

86 lines
3.7 KiB
Python

"""
BF16 Packing Diagnostic: Run identity softmax with K=V=randn,
compare output vs reference to identify the error pattern.
"""
import torch, sys
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests')
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from test_stage_b_v7 import StageBIdentitySoftmax
# Deterministic bf16 operands. K and V share one tensor (`kv`), so the
# expected identity-softmax output is (Q @ K^T) @ K.
torch.manual_seed(42)
m = n = k = 128
tensor_opts = {'dtype': torch.bfloat16, 'device': 'cuda'}
q = torch.randn((m, k, 1), **tensor_opts)
kv = torch.randn((n, k, 1), **tensor_opts)
c = torch.zeros((m, n, 1), **tensor_opts)  # the kernel writes its result here

# fp32 reference built from the same bf16 inputs (trailing unit dim dropped).
qf = q[..., 0].float()
kvf = kv[..., 0].float()
scores = qf @ kvf.T  # P = Q @ K^T; identity softmax leaves P unchanged
ref = scores @ kvf   # P @ V

# Wrap each tensor for CuTe with a dynamic layout along its leading dim.
mQ, mK, mC = [
    ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))
    for t in (q, kv, c)
]
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
launch_args = (mQ, mK, mC, stream)

kernel = StageBIdentitySoftmax(
    mma_tiler_mn=(128, 128),
    use_2cta_instrs=False,
    use_tma_store=True,
)
print('Compiling...', flush=True)
compiled = cute.compile(kernel, *launch_args)
print('Running...', flush=True)
compiled(*launch_args)
torch.cuda.synchronize()  # make sure the kernel finished before reading `c`

out = c[..., 0].float()
cos = torch.nn.functional.cosine_similarity(
    out.reshape(1, -1), ref.reshape(1, -1)
).item()
print(f'\nCosine: {cos:.6f}')
for label, tensor in (('Output', out), ('Ref', ref)):
    print(f'{label} row 0[:8]: {tensor[0, :8].tolist()}')
# Key diagnostic: Q@K^T is the stage we know is correct, so if the output is
# wrong, the error pattern should come from the PV stage.
# With identity softmax, P = Q@K^T, so output = P @ V = (Q@K^T) @ V.
# If V is being read wrong, the output will be P @ V_permuted.
#
# Each check below builds what one failure mode would produce and reports how
# closely the kernel output matches it.
#
# NOTE: the P @ V^T, P^T @ V and "V alone" hypotheses only have compatible
# shapes when the output, Q and K/V are all the same square size
# (m == n == k). Otherwise the matmuls or the flattened comparisons fail.


def _flat_cos(a, b):
    """Return the cosine similarity of *a* and *b*, each viewed as one flat vector."""
    return torch.nn.functional.cosine_similarity(
        a.flatten().unsqueeze(0), b.flatten().unsqueeze(0)
    ).item()


# Check: is the output a permutation of the reference? Sort both and compare.
out_sorted = out.flatten().sort().values
ref_sorted = ref.flatten().sort().values
cos_sorted = _flat_cos(out_sorted, ref_sorted)
print(f'\nCosine after sorting: {cos_sorted:.6f}')

# Why the sorted cosine alone is not enough:
#   - Cosine is scale-invariant.
#   - Once sorted, ANY two similarly distributed value sets score ~1.0.
# So a high sorted cosine cannot confirm a permutation on its own.
#
# A true permutation also matches value-for-value after sorting. So compare
# the max abs error before and after sorting, relative to the reference's
# largest magnitude.
print('(~1.0 is necessary, not sufficient, for a permutation of reference)')
ref_scale = ref.abs().max().clamp_min(1e-12)
rel_err = ((out.flatten() - ref.flatten()).abs().max() / ref_scale).item()
rel_err_sorted = ((out_sorted - ref_sorted).abs().max() / ref_scale).item()
print(f'Max rel error: unsorted={rel_err:.6f}, sorted={rel_err_sorted:.6f}')
print('(small sorted error with large unsorted error => permuted output)')

# Check: does output match P @ V^T? (V transposed)
ref_vt = qf @ kvf.T @ kvf.T  # (Q@K^T) @ V^T
cos_vt = _flat_cos(out, ref_vt)
print(f'Cosine with P @ V^T: {cos_vt:.6f}')

# Check: does output match P^T @ V? (P transposed)
# P = Q@K^T, so P^T = K@Q^T
ref_pt = kvf @ qf.T @ kvf
cos_pt = _flat_cos(out, ref_pt)
print(f'Cosine with P^T @ V: {cos_pt:.6f}')

# Check: is output simply the Q@K^T scores (MMA2 produced identity)?
# If MMA2 didn't run or produced P unchanged, output = P = Q@K^T
qkt = qf @ kvf.T
cos_qkt = _flat_cos(out, qkt)
print(f'Cosine with just Q@K^T: {cos_qkt:.6f}')

# Check: are ALL output rows identical? (means P has identical rows, like
# all-ones). Every row is compared with row 0, so two matching leading rows
# alone are not reported as "all identical". A single-row output is trivially
# all-identical.
row_matches = torch.isclose(out, out[0].expand_as(out), atol=1e-3).all(dim=1)
all_same = bool(row_matches.all())
print(f'All output rows identical: {all_same}')
if not all_same:
    print(f'Rows matching row 0: {int(row_matches.sum())}/{out.shape[0]}')
    # Row 0 can fail to match itself only via NaNs, so guard the index.
    if out.shape[0] > 1:
        cos_r01 = _flat_cos(out[0], out[1])
        print(f'Cosine between row 0 and row 1: {cos_r01:.6f}')

# Check: is output just V (P=I case)?
cos_v = _flat_cos(out, kvf)
print(f'Cosine with V alone: {cos_v:.6f}')

# Print some output statistics
print(f'\nOutput stats: min={out.min().item():.4f}, max={out.max().item():.4f}, mean={out.mean().item():.4f}')
print(f'Ref stats: min={ref.min().item():.4f}, max={ref.max().item():.4f}, mean={ref.mean().item():.4f}')