Files
nvfp4-megamoe-kernel/tests/test_diag_permute.py
2026-05-21 05:08:57 +00:00

81 lines
3.7 KiB
Python

"""
Quick diagnostic: truncated identity V with 128x64 PV.
Check if output columns match a permutation of reference columns.
If O[m,d] = P[m, perm(d)], then the PV MMA is reading P from wrong TMEM addresses.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05, OperandMajorMode
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct
# The numbers analysed below were recorded from an earlier run of
# DiagVTruncIdKernel (see test_diag_v_truncid.py).
# The kernel is NOT re-run here: this script only performs the permutation
# analysis in Python on the results we already have.
# Recorded values from that run:
# O[0,:5] = [6.0625, 11.875, -9.5625, -4.6875, -14.9375]
# ref[0,:5] = [6.0625, 10.5625, 11.875, -11.75, -9.5625]
# P[0] (the full Q@K^T row 0) still needs to be computed; that is done next.
# Regenerate Q and K from a fixed seed and form the reference score matrix
# P = Q @ K^T in FP32.
# NOTE(review): the seed, shapes, draw order and CUDA device presumably mirror
# the kernel test that produced the recorded O values — confirm before changing
# any of them, since the generated numbers depend on all four.
torch.manual_seed(42)
m = n = 128
head_dim = 64
# The trailing singleton dimension is part of the generated shape; it is
# squeezed away below before the matmul.
q = torch.randn((m, head_dim, 1), dtype=torch.bfloat16, device="cuda")
k = torch.randn((n, head_dim, 1), dtype=torch.bfloat16, device="cuda")
qf = q.squeeze(-1).to(torch.float32)
kf = k.squeeze(-1).to(torch.float32)
P = torch.matmul(qf, kf.t())  # (128, 128) — the P matrix
# Now check: does O[0, d] = P[0, perm(d)] for some permutation?
# O[0,0] = 6.0625 → matches P[0,0] = 6.0625
# O[0,1] = 11.875 → matches P[0,2] = 11.875
# O[0,2] = -9.5625 → matches P[0,4] = -9.5625
# So O[0, d] = P[0, 2*d]? Let me check more.
# Kernel output O[0, :5] as recorded from the earlier run (hard-coded because
# the kernel is not executed by this script).
O_row0 = [6.0625, 11.875, -9.5625, -4.6875, -14.9375]
P_row0 = P[0, :10].tolist()
print(f"P[0, :10] = {P_row0}")
print(f"O[0, :5] = {O_row0}")
# Check: O[0, d] = P[0, 2*d]?
# P_row0 holds FP32 values, whereas every recorded O value lies exactly on the
# BF16 grid (presumably because P is packed to BF16 before the PV MMA —
# confirm). BF16 keeps 8 significant bits, so rounding alone can move a value
# by up to 2^-8 relative (about 0.03 absolute for magnitudes in [8, 16)).
# A fixed absolute tolerance of 0.01 would therefore reject genuine matches;
# compare with a relative tolerance of one BF16 ulp instead.
BF16_REL_TOL = 2.0 ** -7
for d, observed in enumerate(O_row0):
    expected = P_row0[2 * d]
    is_match = abs(observed - expected) <= BF16_REL_TOL * max(abs(observed), abs(expected))
    print(f" O[0,{d}] = {observed:.4f}, P[0,{2*d}] = {expected:.4f}, match = {is_match}")
# Also check full P row 0 vs O
# We can't get O without running the kernel again, but the pattern is clear:
# O[m, d] = P[m, 2*d] for the truncated identity V case
# This means the PV MMA is reading P from every other TMEM column
# Why 2*d? Because with (128,64) MMA, the A fragment reads TMEM with stride 2 in the K dimension.
# The (128,64,16) MMA atom has N=64, which means it reads 64 columns of P per K-tile
# But P has 128 columns. The MMA reads the first 64, but with the wrong stride.
#
# Actually, with (128,64,16) MMA:
# - A operand: (M=128, K=128) → MMA reads 128/16 = 8 K-tiles
# - Each K-tile reads P[:, k*16:(k+1)*16] = 16 columns of P
# - The A fragment for K-tile kb reads from TMEM column offset based on N_MMA
#
# The (128,64,16) MMA's TMEM A fragment layout might be:
# (128, N_MMA) where N_MMA relates to the N dimension of the MMA
# If N_MMA = 64 (half of 128), then P's 128 BF16 values in K are stored
# in 128 BF16 TMEM columns = 64 FP32 TMEM columns
# But the (128,64,16) A fragment might only address 32 FP32 TMEM columns
# because the MMA only uses 64 columns for the C output
# So P's 128 K values don't fit in 32 TMEM columns, and the layout is different
# The root cause: the (128,64) MMA's A fragment in TMEM packs 128 BF16 K values
# into fewer TMEM columns than the (128,128) MMA. The softmax packing writes P
# using the (128,128) layout, but the PV MMA reads with the (128,64) layout.
# Emit the working hypothesis: a banner line (preceded by a blank line)
# followed by five explanatory lines, one per output line.
print(
    "\n=== HYPOTHESIS ===",
    "The (128,64,16) MMA atom reads P from TMEM with a DIFFERENT layout",
    "than the softmax packing writes P with (QK C fragment layout).",
    "The (128,128,16) MMA atom's A fragment layout matches the QK C fragment layout,",
    "so the 128x128 case works. The (128,64,16) layout differs, causing the bug.",
    "Fix: softmax packing should write P using the PV MMA's A fragment layout.",
    sep="\n",
)