Files
nvfp4-megamoe-kernel/tests/archive/test_diag_v_ones.py
biondizzle 9cbdc92744 Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00

63 lines
2.9 KiB
Python

"""
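Diagnostic: PV with V = all ones, supplied as a (head_dim=64, n=128) tensor
indexed [d, k] and passed to the kernel MN-major (d contiguous).
O[m, d] should equal sum_k P[m, k] for every d (all columns identical);
if the columns differ, the PV MMA is reading V incorrectly.
Also tests PV with a single-element V (V[d=0, k=0] = 1, rest 0):
O[m, 0] = P[m, 0] and all other entries are 0.
Here P = bf16(Q K^T) with no softmax, as in the host reference.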
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05, OperandMajorMode
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct
# Reuse the same kernel as test_diag_v_truncid.py but with different V
def run_test(v_data, desc, head_dim=64):
    """Run DiagVTruncIdKernel on random Q/K and the given V; report PV accuracy.

    Builds random bf16 Q (m, head_dim) and K (n, head_dim), hands the kernel a
    d-contiguous (MN-major) view of V, and compares the kernel output against
    the host reference ``bf16(Q @ K^T) @ V`` by cosine similarity. When every
    row of ``v_data`` is identical (e.g. all-ones V) it also checks that all
    output columns match, since differing columns mean the PV MMA is reading
    V incorrectly.

    Args:
        v_data: V as a ``(head_dim, n)`` tensor indexed ``[d, k]`` with
            ``n = 128``. It is converted to a bfloat16 CUDA tensor; its own
            memory layout does not matter.
        desc: Label used in the progress and result messages.
        head_dim: Size of the d axis (Q/K feature dim and output width).

    Returns:
        The cosine similarity between the kernel output and the reference.

    Raises:
        ValueError: If ``v_data`` is not shaped ``(head_dim, n)``.
    """
    m, n = 128, 128
    # Fail fast, before any GPU work, on a V that does not match the problem size.
    if tuple(v_data.shape) != (head_dim, n):
        raise ValueError(
            f'v_data must have shape ({head_dim}, {n}), got {tuple(v_data.shape)}'
        )
    q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    # Logical V values in the dtype/device the kernel consumes; the reference
    # below uses these same (bf16-rounded) values.
    v_logical = v_data.to(device='cuda', dtype=torch.bfloat16)
    # Hand the kernel a (head_dim, n, 1) view with strides (1, head_dim, 1),
    # i.e. d contiguous. Going through transpose + contiguous keeps
    # v[d, k] == v_data[d, k]; reinterpreting v_data's own storage with
    # as_strided would permute the elements unless v_data were already stored
    # d-contiguous.
    v = v_logical.t().contiguous().t().unsqueeze(-1)
    c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')

    qf = q[:, :, 0].float()
    kf = k[:, :, 0].float()
    vf = v_logical.float()
    # Reference: P = bf16(Q K^T) with no softmax or scaling, then O = P @ V
    # where V[k, d] = v_data[d, k].
    # NOTE(review): presumably mirrors what DiagVTruncIdKernel computes —
    # confirm against test_diag_v_truncid.py.
    ref = (qf @ kf.T).bfloat16().float() @ vf.T

    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    mV = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # Import the kernel from test_diag_v_truncid (must be importable from
    # sys.path, e.g. by running from the tests/archive directory).
    from test_diag_v_truncid import DiagVTruncIdKernel
    kernel = DiagVTruncIdKernel(mma_tiler_mn=(m, n), head_dim=head_dim)
    print(f'Compiling {desc}...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
    print(f'Running {desc}...', flush=True)
    compiled(mQ, mK, mV, mC, stream)
    torch.cuda.synchronize()

    out = c[:, :, 0].float()
    cos = torch.nn.functional.cosine_similarity(
        out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
    ).item()
    print(f'{desc}: cosine {cos:.6f} {"PASS" if cos >= 0.99 else "FAIL"}')
    if cos < 0.99:
        print(f' O[0,:5] = {out[0,:5].tolist()}')
        print(f' ref[0,:5] = {ref[0,:5].tolist()}')
    # If every row of v_data (every d) is the same, O[m, d] is the same for
    # every d, so all output columns must match. Check every column, not a
    # prefix. (The old "V is nonzero" guard also fired for the single-element
    # V, where differing columns are expected.)
    if bool((v_logical == v_logical[:1]).all()):
        all_same = torch.allclose(out, out[:, :1].expand_as(out), atol=0.1)
        print(f' All columns same? {all_same}')
    return cos
# Entry point: run both V diagnostics only when executed as a script, so that
# importing this module (e.g. during pytest collection) does not launch GPU work.
if __name__ == "__main__":
    # Shapes match run_test's expected (head_dim, n) = (64, 128).
    head_dim, n = 64, 128
    # Test 1: all-ones V -> every column of O should equal the row sums of P.
    v_ones = torch.ones(head_dim, n, dtype=torch.bfloat16, device='cuda')
    run_test(v_ones, "All-ones V", head_dim=head_dim)
    # Test 2: single nonzero element V[d=0, k=0] = 1 -> O[:, 0] = P[:, 0], rest 0.
    v_single = torch.zeros(head_dim, n, dtype=torch.bfloat16, device='cuda')
    v_single[0, 0] = 1.0
    run_test(v_single, "Single-element V", head_dim=head_dim)