D1.3: Enhanced diagnostic - test QK C-fragment as source for make_tiled_copy_C

This commit is contained in:
2026-05-23 22:24:15 +00:00
parent a90fe41b6b
commit d10cab7a8e

View File

@@ -123,27 +123,38 @@ def main():
tiled_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)
thr_copy = tiled_copy.get_slice(sfw_idx)
print(f" OK, created tiled_copy")
tD = thr_copy.partition_D(sP_2d)
print(f" partition_D(sP_2d) shape: {cute.shape(tD)} rank={len(cute.shape(tD))}")
# Create source tensor using the QK C-fragment register layout
# This is what make_tiled_copy_C expects as source
qk_C_reg = qk_thr.make_fragment_C(qk_as) # register fragment
# But we need a 2D version (without stage mode)
qk_C_2d = cute.group_modes(qk_C_reg, 0, 2)
print(f" qk_C_reg shape: {cute.shape(qk_C_reg)} layout: {qk_C_reg.layout}")
print(f" qk_C_2d shape: {cute.shape(qk_C_2d)} layout: {qk_C_2d.layout}")
try:
tD = thr_copy.partition_D(sP_2d)
print(f" partition_D(sP_2d) shape: {cute.shape(tD)} rank={len(cute.shape(tD))}")
tS = thr_copy.partition_S(qk_C_2d)
print(f" partition_S(qk_C_2d) shape: {cute.shape(tS)} rank={len(cute.shape(tS))}")
print(f" partition_S and partition_D RANKS MATCH: {len(cute.shape(tS)) == len(cute.shape(tD))}")
except Exception as e:
print(f" partition_D(sP_2d) FAILED: {e}")
# Try source partition with coord layout
try:
rP_coord = cute.make_tensor(cute.find(tTMEM_LOADcS.layout, 0), tTMEM_LOADcS.layout)
tS = thr_copy.partition_S(rP_coord)
print(f" partition_S(coord) shape: {cute.shape(tS)} rank={len(cute.shape(tS))}")
except Exception as e:
print(f" partition_S(coord) FAILED: {e}")
# Try source with QK C-frag layout
try:
rP_qk = cute.make_tensor(cute.find(tStS.layout, 0), tStS.layout)
tS = thr_copy.partition_S(rS=rP_qk)
print(f" partition_S(QK C-frag) shape: {cute.shape(tS)} rank={len(cute.shape(tS))}")
except Exception as e:
print(f" partition_S(QK C-frag) FAILED: {e}")
print(f" partition_S(qk_C_2d) FAILED: {e}")
# Try with a P-sized sub-tensor of the QK C-fragment
# P is only the first p_cols_fp32 columns of S
p_cols_bf16 = pv_mma_tiler[2] # K dimension of PV = columns of P
print(f" p_cols_bf16 (PV tiler[2]): {p_cols_bf16}")
qk_P_shape = (128, p_cols_bf16) # P matrix is (128, p_cols_bf16)
qk_P_layout = cute.composition(qk_C_reg.layout, cute.make_layout(qk_P_shape))
qk_P = cute.make_tensor(0, qk_P_layout)
qk_P_2d = cute.group_modes(qk_P, 0, 2)
print(f" qk_P_2d shape: {cute.shape(qk_P_2d)} layout: {qk_P_2d.layout}")
try:
tS = thr_copy.partition_S(qk_P_2d)
print(f" partition_S(qk_P_2d) shape: {cute.shape(tS)} rank={len(cute.shape(tS))}")
print(f" RANKS MATCH: {len(cute.shape(tS)) == len(cute.shape(tD))}")
except Exception as e2:
print(f" partition_S(qk_P_2d) FAILED: {e2}")
except Exception as e:
print(f" make_tiled_copy_C FAILED: {e}")
@@ -152,15 +163,21 @@ def main():
try:
tiled_copy_pv = cute.make_tiled_copy_C(smem_copy_atom, pv_mma)
thr_copy_pv = tiled_copy_pv.get_slice(sfw_idx)
print(f" OK, created tiled_copy_pv")
try:
tD = thr_copy_pv.partition_D(sP_2d)
print(f" partition_D(sP_2d) shape: {cute.shape(tD)} rank={len(cute.shape(tD))}")
except Exception as e:
print(f" partition_D(sP_2d) FAILED: {e}")
tD_pv = thr_copy_pv.partition_D(sP_2d)
print(f" OK, partition_D(sP_2d) shape: {cute.shape(tD_pv)} rank={len(cute.shape(tD_pv))}")
except Exception as e:
print(f" make_tiled_copy_C(pv_mma) FAILED: {e}")
# KEY QUESTION: what is the relationship between QK C-fragment registers
# and the TMEM load partition? Are they the same threads, same values?
print(f"\n --- QK C-fragment vs TMEM load partition ---")
print(f" tStS shape: {cute.shape(tStS)} layout: {tStS.layout}")
print(f" tTMEM_LOADtS shape: {cute.shape(tTMEM_LOADtS)}")
print(f" tTMEM_LOADcS shape: {cute.shape(tTMEM_LOADcS)} layout: {tTMEM_LOADcS.layout}")
print(f" size(tStS) = {cute.size(tStS)}, size(tTMEM_LOADcS) = {cute.size(tTMEM_LOADcS)}")
# Do the C-fragment and TMEM load have the same per-thread element count?
# Equal counts are necessary but not sufficient for the same threads to own the same elements in both views; the layouts printed above must also agree.
# Compile and run
import cuda.bindings.driver as cuda
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)