D1.3: Enhanced diagnostic - test QK C-fragment as source for make_tiled_copy_C
This commit is contained in:
@@ -123,27 +123,38 @@ def main():
|
||||
tiled_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)
|
||||
thr_copy = tiled_copy.get_slice(sfw_idx)
|
||||
print(f" OK, created tiled_copy")
|
||||
tD = thr_copy.partition_D(sP_2d)
|
||||
print(f" partition_D(sP_2d) shape: {cute.shape(tD)} rank={len(cute.shape(tD))}")
|
||||
|
||||
# Create source tensor using the QK C-fragment register layout
|
||||
# This is what make_tiled_copy_C expects as source
|
||||
qk_C_reg = qk_thr.make_fragment_C(qk_as) # register fragment
|
||||
# But we need a 2D version (without stage mode)
|
||||
qk_C_2d = cute.group_modes(qk_C_reg, 0, 2)
|
||||
print(f" qk_C_reg shape: {cute.shape(qk_C_reg)} layout: {qk_C_reg.layout}")
|
||||
print(f" qk_C_2d shape: {cute.shape(qk_C_2d)} layout: {qk_C_2d.layout}")
|
||||
|
||||
try:
|
||||
tD = thr_copy.partition_D(sP_2d)
|
||||
print(f" partition_D(sP_2d) shape: {cute.shape(tD)} rank={len(cute.shape(tD))}")
|
||||
tS = thr_copy.partition_S(qk_C_2d)
|
||||
print(f" partition_S(qk_C_2d) shape: {cute.shape(tS)} rank={len(cute.shape(tS))}")
|
||||
print(f" partition_S and partition_D RANKS MATCH: {len(cute.shape(tS)) == len(cute.shape(tD))}")
|
||||
except Exception as e:
|
||||
print(f" partition_D(sP_2d) FAILED: {e}")
|
||||
|
||||
# Try source partition with coord layout
|
||||
try:
|
||||
rP_coord = cute.make_tensor(cute.find(tTMEM_LOADcS.layout, 0), tTMEM_LOADcS.layout)
|
||||
tS = thr_copy.partition_S(rP_coord)
|
||||
print(f" partition_S(coord) shape: {cute.shape(tS)} rank={len(cute.shape(tS))}")
|
||||
except Exception as e:
|
||||
print(f" partition_S(coord) FAILED: {e}")
|
||||
|
||||
# Try source with QK C-frag layout
|
||||
try:
|
||||
rP_qk = cute.make_tensor(cute.find(tStS.layout, 0), tStS.layout)
|
||||
tS = thr_copy.partition_S(rS=rP_qk)
|
||||
print(f" partition_S(QK C-frag) shape: {cute.shape(tS)} rank={len(cute.shape(tS))}")
|
||||
except Exception as e:
|
||||
print(f" partition_S(QK C-frag) FAILED: {e}")
|
||||
print(f" partition_S(qk_C_2d) FAILED: {e}")
|
||||
# Try with a P-sized sub-tensor of the QK C-fragment
|
||||
# P is only the first p_cols_fp32 columns of S (NOTE(review): the code below sizes P with p_cols_bf16 - confirm which unit is intended)
|
||||
p_cols_bf16 = pv_mma_tiler[2] # K dimension of PV = columns of P
|
||||
print(f" p_cols_bf16 (PV tiler[2]): {p_cols_bf16}")
|
||||
qk_P_shape = (128, p_cols_bf16) # P matrix is (128, p_cols_bf16)
|
||||
qk_P_layout = cute.composition(qk_C_reg.layout, cute.make_layout(qk_P_shape))
|
||||
qk_P = cute.make_tensor(0, qk_P_layout)
|
||||
qk_P_2d = cute.group_modes(qk_P, 0, 2)
|
||||
print(f" qk_P_2d shape: {cute.shape(qk_P_2d)} layout: {qk_P_2d.layout}")
|
||||
try:
|
||||
tS = thr_copy.partition_S(qk_P_2d)
|
||||
print(f" partition_S(qk_P_2d) shape: {cute.shape(tS)} rank={len(cute.shape(tS))}")
|
||||
print(f" RANKS MATCH: {len(cute.shape(tS)) == len(cute.shape(tD))}")
|
||||
except Exception as e2:
|
||||
print(f" partition_S(qk_P_2d) FAILED: {e2}")
|
||||
except Exception as e:
|
||||
print(f" make_tiled_copy_C FAILED: {e}")
|
||||
|
||||
@@ -152,15 +163,21 @@ def main():
|
||||
try:
|
||||
tiled_copy_pv = cute.make_tiled_copy_C(smem_copy_atom, pv_mma)
|
||||
thr_copy_pv = tiled_copy_pv.get_slice(sfw_idx)
|
||||
print(f" OK, created tiled_copy_pv")
|
||||
try:
|
||||
tD = thr_copy_pv.partition_D(sP_2d)
|
||||
print(f" partition_D(sP_2d) shape: {cute.shape(tD)} rank={len(cute.shape(tD))}")
|
||||
except Exception as e:
|
||||
print(f" partition_D(sP_2d) FAILED: {e}")
|
||||
tD_pv = thr_copy_pv.partition_D(sP_2d)
|
||||
print(f" OK, partition_D(sP_2d) shape: {cute.shape(tD_pv)} rank={len(cute.shape(tD_pv))}")
|
||||
except Exception as e:
|
||||
print(f" make_tiled_copy_C(pv_mma) FAILED: {e}")
|
||||
|
||||
# KEY QUESTION: what is the relationship between QK C-fragment registers
|
||||
# and the TMEM load partition? Are they the same threads, same values?
|
||||
print(f"\n --- QK C-fragment vs TMEM load partition ---")
|
||||
print(f" tStS shape: {cute.shape(tStS)} layout: {tStS.layout}")
|
||||
print(f" tTMEM_LOADtS shape: {cute.shape(tTMEM_LOADtS)}")
|
||||
print(f" tTMEM_LOADcS shape: {cute.shape(tTMEM_LOADcS)} layout: {tTMEM_LOADcS.layout}")
|
||||
print(f" size(tStS) = {cute.size(tStS)}, size(tTMEM_LOADcS) = {cute.size(tTMEM_LOADcS)}")
|
||||
# Do the C-fragment and TMEM load have the same per-thread element count?
|
||||
# If so, that is consistent with (but does not by itself prove) the same threads owning the same elements in both views.
|
||||
|
||||
# Compile and run
|
||||
import cuda.bindings.driver as cuda
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
Reference in New Issue
Block a user