- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py - Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc. - Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda) - Moved PyTorch bridges to dsv4/ops/ - Moved nn.Module layers to dsv4/layers/ - Moved reference implementations to dsv4/reference/ - Moved vendored CUTLASS code to vendored/ - Archived ~190 debug tests to tests/archive/ - Kept ~15 canonical tests in tests/unit/ - Updated all import paths - Added stubs for future components (model/, cache/, loader/) - Updated pyproject.toml: dsv4-inference package name
63 lines
2.9 KiB
Python
63 lines
2.9 KiB
Python
"""
|
|
Diagnostic: PV with V = all ones (64, 128) MN-major.
|
|
O[m, d] should be sum_k P[m, k] for all d (all columns identical).
|
|
If columns differ, the PV MMA is reading V incorrectly.
|
|
Also test: PV with V = single element (d=0, k=0) = 1, rest 0.
|
|
O[m, 0] = P[m, 0], all other entries 0.
|
|
"""
|
|
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
|
from cutlass.cute.nvgpu import cpasync, tcgen05, OperandMajorMode
|
|
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
|
from cutlass.utils import LayoutEnum
|
|
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
|
import cuda.bindings.driver as cuda
|
|
import cutlass.torch as ct
|
|
|
|
# Reuse the same kernel as test_diag_v_truncid.py but with different V
|
|
|
|
def run_test(v_data, desc, head_dim=64):
    """Run the diagnostic kernel with a caller-supplied V and compare to PyTorch.

    Random bfloat16 Q ``(128, head_dim)`` and K ``(128, head_dim)`` are
    generated on the GPU.  ``DiagVTruncIdKernel`` (imported from
    ``test_diag_v_truncid``; its implementation is not visible here) is
    compiled and launched on Q, K, V and a zero-initialised output buffer.
    The result is compared by cosine similarity against the reference
    ``bf16(Q @ K^T) @ v_data^T`` and a PASS/FAIL line (threshold 0.99) is
    printed.  On failure, the first five values of row 0 of the output and
    of the reference are printed as well.

    Args:
        v_data: CUDA tensor of shape ``(head_dim, 128)``.  ``v_data[d, k]``
            is the value the reference multiplies with ``P[m, k]`` to
            produce ``O[m, d]``.
        desc: Label used in the progress and result messages.
        head_dim: Size of the head dimension of Q, K, V and the output.

    Raises:
        ValueError: If ``v_data`` does not have shape ``(head_dim, 128)``.
    """
    m, n = 128, 128
    pass_threshold = 0.99

    # Fail early with a readable message instead of an opaque stride or
    # matmul error further down.
    if tuple(v_data.shape) != (head_dim, n):
        raise ValueError(
            f"v_data must have shape ({head_dim}, {n}), got {tuple(v_data.shape)}"
        )

    q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    # Re-lay V out so that the head_dim axis has stride 1 (shape (head_dim, n),
    # strides (1, head_dim)) while keeping the logical values v[d, k] ==
    # v_data[d, k].  Using as_strided on the row-major storage would give the
    # same strides but reinterpret the memory, which only matches the
    # reference below for degenerate inputs such as all-ones.
    v = v_data.t().contiguous().t().unsqueeze(-1)
    c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')

    # Reference: P is rounded through bfloat16 before the second matmul.
    qf = q[:, :, 0].float()
    kf = k[:, :, 0].float()
    vf = v_data.float()
    ref = (qf @ kf.T).bfloat16().float() @ vf.T

    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    mV = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # Import the kernel from test_diag_v_truncid
    from test_diag_v_truncid import DiagVTruncIdKernel

    kernel = DiagVTruncIdKernel(mma_tiler_mn=(128, 128), head_dim=head_dim)
    print(f'Compiling {desc}...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
    print(f'Running {desc}...', flush=True)
    compiled(mQ, mK, mV, mC, stream)
    torch.cuda.synchronize()

    out = c[:, :, 0].float()
    cos = torch.nn.functional.cosine_similarity(
        out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
    ).item()
    # A NaN similarity compares False here and is therefore reported as FAIL.
    passed = cos >= pass_threshold
    print(f'{desc}: cosine {cos:.6f} {"PASS" if passed else "FAIL"}')

    # Use `not passed` rather than `cos < threshold` so that a NaN result
    # also prints the diagnostics below.
    if not passed:
        print(f' O[0,:5] = {out[0,:5].tolist()}')
        print(f' ref[0,:5] = {ref[0,:5].tolist()}')
        # Check if columns are all the same (for all-ones V).
        # NOTE(review): this runs for any non-zero V, but identical columns
        # are only expected when V is uniform.
        if v_data.abs().sum() > 0:
            col0 = out[:, 0]
            all_same = all(
                torch.allclose(out[:, d], col0, atol=0.1)
                for d in range(min(8, head_dim))
            )
            print(f' All columns same? {all_same}')
|
|
|
|
# Test 1: V filled with ones. Per the module docstring, every output column
# should then be identical.
v_ones = torch.ones((64, 128), device='cuda', dtype=torch.bfloat16)
run_test(v_data=v_ones, desc="All-ones V")

# Test 2: V is zero except for a single 1 at [0, 0]. Per the module docstring,
# only output column 0 should then be non-zero.
v_single = torch.zeros((64, 128), device='cuda', dtype=torch.bfloat16)
v_single[0, 0] = 1.0
run_test(v_data=v_single, desc="Single-element V")
|