- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py - Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc. - Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda) - Moved PyTorch bridges to dsv4/ops/ - Moved nn.Module layers to dsv4/layers/ - Moved reference implementations to dsv4/reference/ - Moved vendored CUTLASS code to vendored/ - Archived ~190 debug tests to tests/archive/ - Kept ~15 canonical tests in tests/unit/ - Updated all import paths - Added stubs for future components (model/, cache/, loader/) - Updated pyproject.toml: set package name to dsv4-inference
86 lines
3.7 KiB
Python
86 lines
3.7 KiB
Python
"""
|
|
BF16 Packing Diagnostic: Run identity softmax with K=V=randn,
|
|
compare output vs reference to identify the error pattern.
|
|
"""
|
|
import torch, sys
|
|
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests')
|
|
|
|
import cutlass.cute as cute
|
|
import cutlass.torch as ct
|
|
import cuda.bindings.driver as cuda
|
|
from test_stage_b_v7 import StageBIdentitySoftmax
|
|
|
|
# --- Inputs and fp32 reference ------------------------------------------------
torch.manual_seed(42)

# All extents are equal on purpose. The kernel output C is (m, n) while the
# reference is (m, k), and several checks below (P @ V^T, P^T @ V, V alone)
# mix extents, so the whole script assumes m == n == k.
m, n, k = 128, 128, 128

# Q is (m, k, 1). A single (n, k, 1) buffer serves as both K and V.
# The kernel writes its result into C, shaped (m, n, 1).
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')

# fp32 2-D copies of the operands for the PyTorch reference.
qf, kvf = (t[:, :, 0].float() for t in (q, kv))

# Softmax is the identity here, so the reference is simply (Q @ K^T) @ V.
ref = (qf @ kvf.T) @ kvf

# --- Kernel launch ------------------------------------------------------------
def _to_cute(t):
    """Wrap torch tensor ``t`` as a CuTe tensor with a dynamic layout.

    The dynamic dimension is whichever one ``ct.get_leading_dim`` reports for ``t``.
    """
    return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))


mQ, mK, mC = (_to_cute(t) for t in (q, kv, c))
# Hand PyTorch's current CUDA stream to the kernel as a driver CUstream handle.
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
launch_args = (mQ, mK, mC, stream)

print('Compiling...', flush=True)
compiled = cute.compile(kernel, *launch_args)

print('Running...', flush=True)
compiled(*launch_args)
# Wait for all queued GPU work before reading ``c`` on the host side.
torch.cuda.synchronize()

# --- Compare against the reference ---------------------------------------------
# 2-D fp32 copy of the kernel output.
out = c[:, :, 0].float()

# One cosine over all elements: ~1.0 means the output points the same way as
# the reference (cosine is insensitive to a positive overall scale).
cos = torch.nn.functional.cosine_similarity(
    out.reshape(1, -1), ref.reshape(1, -1)
).item()

print(f'\nCosine: {cos:.6f}')
print(f'Output row 0[:8]: {out[0,:8].tolist()}')
print(f'Ref row 0[:8]: {ref[0,:8].tolist()}')

# --- Error-pattern hypotheses ---------------------------------------------------
# The Q @ K^T stage is believed correct, so a mismatch is attributed to the PV
# stage. With identity softmax P = Q @ K^T, so the expected output is
# P @ V = (Q @ K^T) @ V. If V is read incorrectly, the kernel instead produces
# something like P @ (permuted V).

# Hypothesis: the output holds the right values in the wrong positions.
# Sorting both tensors discards position information before comparing.
# NOTE(review): sorted cosine is also high for unrelated tensors that merely
# share a value distribution, so ~1.0 here is a hint, not proof.
out_sorted = out.flatten().sort().values
ref_sorted = ref.flatten().sort().values
cos_sorted = torch.nn.functional.cosine_similarity(
    out_sorted[None, :], ref_sorted[None, :]
).item()
print(f'\nCosine after sorting: {cos_sorted:.6f}')
print(f'(If ~1.0, output is a permutation of reference)')

# Hypothesis: V was consumed transposed, i.e. out == P @ V^T.
ref_vt = (qf @ kvf.T) @ kvf.T
cos_vt = torch.nn.functional.cosine_similarity(
    out.reshape(1, -1), ref_vt.reshape(1, -1)
).item()
print(f'Cosine with P @ V^T: {cos_vt:.6f}')

# Hypothesis: the score matrix was consumed transposed, i.e. out == P^T @ V,
# where P^T = (Q @ K^T)^T = K @ Q^T.
ref_pt = (kvf @ qf.T) @ kvf
cos_pt = torch.nn.functional.cosine_similarity(
    out.reshape(1, -1), ref_pt.reshape(1, -1)
).item()
print(f'Cosine with P^T @ V: {cos_pt:.6f}')

# Hypothesis: the second MMA was skipped or passed P through unchanged,
# so the output is just the scores P = Q @ K^T.
qkt = qf @ kvf.T
cos_qkt = torch.nn.functional.cosine_similarity(
    out.reshape(1, -1), qkt.reshape(1, -1)
).item()
print(f'Cosine with just Q@K^T: {cos_qkt:.6f}')

# Check: are ALL output rows identical? That would happen if P had identical
# rows (e.g. all ones).
# Every row is compared against row 0, not just row 1, so the printed
# "All output rows identical" claim reflects what was actually tested.
all_same = torch.allclose(out, out[0].expand_as(out), atol=1e-3)
print(f'All output rows identical: {all_same}')
if not all_same:
    # Reaching here implies at least two rows, so out[1] exists.
    cos_r01 = torch.nn.functional.cosine_similarity(out[0].unsqueeze(0), out[1].unsqueeze(0)).item()
    print(f'Cosine between row 0 and row 1: {cos_r01:.6f}')

# Hypothesis: P behaved like the identity matrix, so the output is just V.
# The element-wise comparison of the (m, n) output with the (n, k) V relies on
# both having the same shape, which holds because m == n == k.
cos_v = torch.nn.functional.cosine_similarity(
    out.reshape(1, -1), kvf.reshape(1, -1)
).item()
print(f'Cosine with V alone: {cos_v:.6f}')

# --- Value-range summary: kernel output vs. fp32 reference ----------------------
print(f'\nOutput stats: min={out.min().item():.4f}, max={out.max().item():.4f}, mean={out.mean().item():.4f}')
print(f'Ref stats: min={ref.min().item():.4f}, max={ref.max().item():.4f}, mean={ref.mean().item():.4f}')