"""
Quick diagnostic: truncated identity V with a 128x64 PV MMA.

Check whether the output columns match a permutation of the reference columns.
If O[m, d] = P[m, perm(d)] for some column mapping perm, then the PV MMA is
reading P from the wrong TMEM addresses.
"""

import math

import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05, OperandMajorMode
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct

# This script does not launch a kernel. It reuses output recorded from an earlier
# run of DiagVTruncIdKernel (test_diag_v_truncid.py) and analyzes it in Python.
#
# Recorded values from that run:
#   O[0, :5]   = [6.0625, 11.875, -9.5625, -4.6875, -14.9375]   (kernel output)
#   ref[0, :5] = [6.0625, 10.5625, 11.875, -11.75, -9.5625]    (reference)
#
# P[0] (row 0 of the full Q @ K^T) is recomputed below with torch.manual_seed(42).
# NOTE(review): this assumes the kernel run drew Q and K the same way; verify.

# Kernel output O[0, :5] recorded from an earlier run of the truncated-identity-V
# kernel test. These values are BF16-representable, so the kernel presumably emits
# BF16 (TODO confirm).
O_ROW0 = [6.0625, 11.875, -9.5625, -4.6875, -14.9375]

# BF16 keeps 8 significant bits, so neighbouring representable values are up to
# eps = 2**-7 of the magnitude apart. A fixed absolute tolerance of 0.01 is smaller
# than a single BF16 rounding error once |x| >= 4, which is why the comparison is
# relative to one BF16 step instead.
BF16_REL_TOL = torch.finfo(torch.bfloat16).eps
ABS_TOL = 1e-2


def values_match(a, b, *, rel_tol=BF16_REL_TOL, abs_tol=ABS_TOL):
    """Return True if ``a`` and ``b`` agree to within about one BF16 rounding step."""
    return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)


def find_matching_columns(o_row, p_row, *, rel_tol=BF16_REL_TOL, abs_tol=ABS_TOL):
    """For each value in ``o_row``, list every index ``j`` of ``p_row`` it matches.

    This makes no assumption about the column mapping, so it tests the
    "O[m, d] = P[m, perm(d)]" hypothesis directly rather than only a stride-2 guess.

    Returns:
        A list with one entry per element of ``o_row``. Each entry is the list of
        matching column indices. An empty entry means that value has no counterpart
        in ``p_row``. More than one index means the match is ambiguous at this
        tolerance.
    """
    return [
        [j for j, p in enumerate(p_row) if values_match(o, p, rel_tol=rel_tol, abs_tol=abs_tol)]
        for o in o_row
    ]


def compute_p(m=128, n=128, head_dim=64, *, seed=42, device="cuda"):
    """Recompute P = Q @ K^T as an (m, n) float32 tensor from seeded BF16 inputs.

    Seeds torch, then draws Q of shape (m, head_dim, 1) and K of shape
    (n, head_dim, 1) in that order. Keep the order and device unchanged:
    random streams for the same seed are not guaranteed to match across CPU and
    CUDA, so the values only line up with the kernel run under the original
    settings (assumed to be seed 42 on CUDA; verify).
    """
    torch.manual_seed(seed)
    q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device=device)
    k = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device=device)
    qf = q[:, :, 0].float()
    kf = k[:, :, 0].float()
    return qf @ kf.T


def main():
    """Compare the recorded kernel row O[0, :5] with row 0 of P and print a hypothesis."""
    p_full = compute_p()
    p_row0 = p_full[0].tolist()
    # The PV MMA consumes P as BF16 values, so match O against P rounded to BF16.
    p_row0_bf16 = p_full[0].to(torch.bfloat16).float().tolist()

    print(f"P[0, :10] = {p_row0[:10]}")
    print(f"O[0, :5] = {O_ROW0}")

    # Spot check of the stride-2 pattern suggested by the first values:
    #   O[0,0] = 6.0625 -> P[0,0];  O[0,1] = 11.875 -> P[0,2];  O[0,2] = -9.5625 -> P[0,4]
    for d, o in enumerate(O_ROW0):
        p = p_row0_bf16[2 * d]
        print(f"  O[0,{d}] = {o:.4f}, P[0,{2 * d}] = {p:.4f}, match = {values_match(o, p)}")

    # General search: which column(s) of P row 0 does each recorded O value match?
    print("\nColumn search over all of P[0, :]:")
    for d, cols in enumerate(find_matching_columns(O_ROW0, p_row0_bf16)):
        print(f"  O[0,{d}] matches P[0, j] for j in {cols}")

    # Interpretation (hypothesis, not yet confirmed against the actual TMEM layouts):
    # O[m, d] = P[m, 2*d] would mean the PV MMA reads P from every other TMEM column.
    # With the (128,64,16) MMA atom:
    #   - The A operand is (M=128, K=128), consumed as 128/16 = 8 K-tiles, each
    #     covering P[:, k*16:(k+1)*16].
    #   - P's 128 BF16 K values occupy 128 BF16 = 64 FP32 TMEM columns.
    #   - The A-fragment TMEM layout may be tied to the MMA's N (= 64), so it may
    #     address fewer TMEM columns (e.g. 32 FP32) and lay P out differently.
    # Suspected root cause: softmax packing writes P with the (128,128) QK
    # C-fragment layout, while the PV MMA reads it with the (128,64) A-fragment layout.
    print("\n=== HYPOTHESIS ===")
    print("The (128,64,16) MMA atom reads P from TMEM with a DIFFERENT layout")
    print("than the softmax packing writes P with (QK C fragment layout).")
    print("The (128,128,16) MMA atom's A fragment layout matches the QK C fragment layout,")
    print("so the 128x128 case works. The (128,64,16) layout differs, causing the bug.")
    print("Fix: softmax packing should write P using the PV MMA's A fragment layout.")


if __name__ == "__main__":
    main()