# Source listing metadata: Python, 198 lines, 10 KiB.
"""
D1.3 SMEM-P diagnostic: print ALL relevant shapes and layouts.

Runs layout analysis inside a JIT-compiled function to get MLIR context.
"""
|
|
import torch, math, sys
|
|
import cutlass, cutlass.cute as cute, cutlass.utils as utils
|
|
from cutlass.cute.nvgpu import tcgen05, cpasync
|
|
from cutlass import Float32, BFloat16, Int32, const_expr
|
|
from cutlass.utils import LayoutEnum
|
|
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
|
import cutlass.torch as ct
|
|
|
|
|
|
def main():
    """Print MMA, fragment, SMEM and TMEM-copy layouts for each head dimension.

    For every head dimension in ``(64, 256)`` this:

    1. allocates random bf16 ``q``/``k``/``v`` tensors and a zeroed ``c``
       tensor on the CUDA device,
    2. defines a ``@cute.kernel`` whose body only builds the QK and PV tiled
       MMAs and prints the derived shapes/layouts (its tensor arguments are
       never read),
    3. compiles that kernel with ``cute.compile``, invokes it once and
       synchronizes the device.

    Each ``make_tiled_copy_C`` experiment runs inside its own ``try`` so that
    a failure in one is reported and the remaining diagnostics still print.

    Returns:
        None. All results are written to stdout.
    """
    # Project-local import kept function-scoped (as in the original script) so
    # that merely importing this module does not require the project package.
    from dsv4.kernels.attention.fmha import FmhaKernel

    for hd in (64, 256):
        banner = "=" * 60
        print(f"\n{banner}")
        print(f" HEAD_DIM={hd}")
        print(banner)

        m = 128
        s_k = 128
        # SMEM-sourced P is only selected for head dims above 64.
        use_smem_p = hd > 64
        pv_n_tile = min(hd, 256)

        q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
        k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
        v = torch.randn(s_k, pv_n_tile, dtype=torch.bfloat16, device='cuda')
        c = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')

        # The kernel object is constructed but not used afterwards: the MMAs
        # are rebuilt by hand below, intended (per the original author's note)
        # to mirror what FmhaKernel.__call__ does without running the full
        # kernel.
        # NOTE(review): confirm whether this construction is still needed.
        FmhaKernel(head_dim=hd, s_k=s_k)

        a_major = LayoutEnum.from_tensor(ct.from_dlpack(q)).mma_major_mode()
        b_major = LayoutEnum.from_tensor(ct.from_dlpack(k)).mma_major_mode()

        # Minimal analysis kernel that only prints shapes. It closes over the
        # per-iteration values (a_major, b_major, use_smem_p, pv_n_tile, c) and
        # is compiled and run before the loop advances, so the closure always
        # sees the values of the current iteration.
        # NOTE(review): the plain Python print calls presumably execute while
        # the kernel is being traced/compiled — confirm.
        @cute.kernel
        def _diag_kernel(_q, _k, _v, _c):
            # QK MMA: both operands sourced from SMEM, 128x128 tile.
            qk_mma = utils.sm100.make_trivial_tiled_mma(
                BFloat16, BFloat16, a_major, b_major, Float32,
                tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM,
            )

            # PV MMA: the A operand (P) is read from SMEM when use_smem_p is
            # set, otherwise from TMEM with a K-major A operand.
            pv_a_major = a_major if use_smem_p else cute.nvgpu.OperandMajorMode.K
            pv_source = tcgen05.OperandSource.SMEM if use_smem_p else tcgen05.OperandSource.TMEM
            # NOTE(review): the PV B-operand major mode is derived from `c`,
            # not from `v` — confirm this is intentional.
            c_major = LayoutEnum.from_tensor(ct.from_dlpack(c)).mma_major_mode()

            pv_mma = utils.sm100.make_trivial_tiled_mma(
                BFloat16, BFloat16, pv_a_major, c_major, Float32,
                tcgen05.CtaGroup.ONE, (128, pv_n_tile), pv_source,
            )

            qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
            qk_mma_tiler = (128, 128, qk_ik * 4)
            pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
            pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))

            print(f" QK MMA shape_mnk: {qk_mma.shape_mnk}")
            print(f" PV MMA shape_mnk: {pv_mma.shape_mnk}")
            print(f" QK tiler: {qk_mma_tiler}")
            print(f" PV tiler: {pv_mma_tiler}")
            print(f" use_smem_p: {use_smem_p}")

            # QK C-fragment
            qk_thr = qk_mma.get_slice(0)
            qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
            tStS = qk_thr.make_fragment_C(qk_as)
            print(f" tStS shape: {cute.shape(tStS)} layout: {tStS.layout}")

            # PV C-fragment
            pv_thr = pv_mma.get_slice(0)
            pv_as = pv_thr.partition_shape_C(pv_mma_tiler[:2])
            tOtO = pv_thr.make_fragment_C(pv_as)
            print(f" tOtO shape: {cute.shape(tOtO)} layout: {tOtO.layout}")

            # PV A-operand SMEM layout (sP). No SMEM is allocated here; a
            # tensor over a dummy pointer (0) is built from the outer layout
            # purely so shapes can be queried.
            p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
            sP_alloc = cute.make_tensor(0, p_smem_s.outer)
            print(f" sP shape: {cute.shape(sP_alloc)}")
            sP_2d = cute.group_modes(sP_alloc, 0, 3)
            print(f" sP_2d shape: {cute.shape(sP_2d)} rank={len(cute.shape(sP_2d))}")
            # Printed once; the original emitted this exact line twice.
            print(f" sP_2d layout: {sP_2d.layout}")

            # The PV A-operand fragments are not created from the dummy
            # pointer; only the SMEM layout components are reported.
            print(f" p_smem_s.outer (sP layout): {p_smem_s.outer}")
            print(f" p_smem_s.inner (swizzle): {p_smem_s.inner}")

            # Goal of the sections below (original author's note): find the
            # PV MMA's A-operand thread partition of sP, so the softmax warps
            # can write sP with the same mapping the MMA warp reads with.

            # Softmax TMEM load partition
            sfw_idx = 0
            tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32)
            tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS)
            thr_load = tiled_tmem_load.get_slice(sfw_idx)
            tTMEM_LOADtS = thr_load.partition_S(tStS)
            cS = cute.make_identity_tensor((128, 128))
            tScS = qk_thr.partition_C(cS)
            tTMEM_LOADcS = thr_load.partition_D(tScS)
            print(f" tTMEM_LOADcS (coord) shape: {cute.shape(tTMEM_LOADcS)} layout: {tTMEM_LOADcS.layout}")

            # P sub-layout in TMEM: PV K extent rescaled by the
            # bf16-to-fp32 bit-width ratio.
            p_cols_fp32 = pv_mma_tiler[2] * BFloat16.width // Float32.width
            print(f" p_cols_fp32: {p_cols_fp32}")

            # make_tiled_copy_C with qk_mma
            print("\n --- make_tiled_copy_C(qk_mma) ---")
            try:
                qk_copy_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), BFloat16, num_bits_per_copy=128)
                tiled_copy = cute.make_tiled_copy_C(qk_copy_atom, qk_mma)
                thr_copy = tiled_copy.get_slice(sfw_idx)
                print(" OK, created tiled_copy")
                tD = thr_copy.partition_D(sP_2d)
                print(f" partition_D(sP_2d) shape: {cute.shape(tD)} rank={len(cute.shape(tD))}")

                # Second C-fragment built from the same partition shape as
                # tStS; the original author intended it as the source that
                # make_tiled_copy_C expects — TODO confirm.
                qk_C_reg = qk_thr.make_fragment_C(qk_as)
                # Group the leading modes to get a view without the stage mode.
                qk_C_2d = cute.group_modes(qk_C_reg, 0, 2)
                print(f" qk_C_reg shape: {cute.shape(qk_C_reg)} layout: {qk_C_reg.layout}")
                print(f" qk_C_2d shape: {cute.shape(qk_C_2d)} layout: {qk_C_2d.layout}")

                try:
                    tS = thr_copy.partition_S(qk_C_2d)
                    print(f" partition_S(qk_C_2d) shape: {cute.shape(tS)} rank={len(cute.shape(tS))}")
                    print(f" partition_S and partition_D RANKS MATCH: {len(cute.shape(tS)) == len(cute.shape(tD))}")
                except Exception as e:
                    print(f" partition_S(qk_C_2d) FAILED: {e}")
                    # Fallback: retry with a P-sized sub-tensor of the QK
                    # C-fragment, (128, K extent of the PV tiler).
                    p_cols_bf16 = pv_mma_tiler[2]
                    print(f" p_cols_bf16 (PV tiler[2]): {p_cols_bf16}")
                    qk_P_shape = (128, p_cols_bf16)
                    qk_P_layout = cute.composition(qk_C_reg.layout, cute.make_layout(qk_P_shape))
                    qk_P = cute.make_tensor(0, qk_P_layout)
                    qk_P_2d = cute.group_modes(qk_P, 0, 2)
                    print(f" qk_P_2d shape: {cute.shape(qk_P_2d)} layout: {qk_P_2d.layout}")
                    try:
                        tS = thr_copy.partition_S(qk_P_2d)
                        print(f" partition_S(qk_P_2d) shape: {cute.shape(tS)} rank={len(cute.shape(tS))}")
                        print(f" RANKS MATCH: {len(cute.shape(tS)) == len(cute.shape(tD))}")
                    except Exception as e2:
                        print(f" partition_S(qk_P_2d) FAILED: {e2}")
            except Exception as e:
                # Best-effort diagnostic: report and continue with the rest.
                print(f" make_tiled_copy_C FAILED: {e}")

            # make_tiled_copy_C with pv_mma
            print("\n --- make_tiled_copy_C(pv_mma) ---")
            try:
                # Built here rather than reused from the QK section: that
                # section may fail before binding its atom, which previously
                # showed up here as a misleading NameError.
                pv_copy_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), BFloat16, num_bits_per_copy=128)
                tiled_copy_pv = cute.make_tiled_copy_C(pv_copy_atom, pv_mma)
                thr_copy_pv = tiled_copy_pv.get_slice(sfw_idx)
                tD_pv = thr_copy_pv.partition_D(sP_2d)
                print(f" OK, partition_D(sP_2d) shape: {cute.shape(tD_pv)} rank={len(cute.shape(tD_pv))}")
            except Exception as e:
                print(f" make_tiled_copy_C(pv_mma) FAILED: {e}")

            # KEY QUESTION: how do the QK C-fragment and the TMEM load
            # partition relate? The sizes printed last let the reader compare
            # the element counts of the two views.
            print("\n --- QK C-fragment vs TMEM load partition ---")
            print(f" tStS shape: {cute.shape(tStS)} layout: {tStS.layout}")
            print(f" tTMEM_LOADtS shape: {cute.shape(tTMEM_LOADtS)}")
            print(f" tTMEM_LOADcS shape: {cute.shape(tTMEM_LOADcS)} layout: {tTMEM_LOADcS.layout}")
            print(f" size(tStS) = {cute.size(tStS)}, size(tTMEM_LOADcS) = {cute.size(tTMEM_LOADcS)}")

        # Compile and run
        v3 = v.unsqueeze(-1)  # give v the same rank as q/k/c before wrapping
        mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
        mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
        mV = ct.from_dlpack(v3).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v3))
        mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))

        print(' Compiling diagnostic...', flush=True)
        compiled = cute.compile(_diag_kernel, mQ, mK, mV, mC)
        compiled(mQ, mK, mV, mC)
        torch.cuda.synchronize()
|
|
|
|
|
|
# Script entry point: run the diagnostic only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|