diff --git a/tests/unit/test_d1_3_layout_diag.py b/tests/unit/test_d1_3_layout_diag.py index 756d0ff6..b23874f7 100644 --- a/tests/unit/test_d1_3_layout_diag.py +++ b/tests/unit/test_d1_3_layout_diag.py @@ -123,27 +123,38 @@ def main(): tiled_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma) thr_copy = tiled_copy.get_slice(sfw_idx) print(f" OK, created tiled_copy") + tD = thr_copy.partition_D(sP_2d) + print(f" partition_D(sP_2d) shape: {cute.shape(tD)} rank={len(cute.shape(tD))}") + + # Create source tensor using the QK C-fragment register layout + # This is what make_tiled_copy_C expects as source + qk_C_reg = qk_thr.make_fragment_C(qk_as) # register fragment + # But we need a 2D version (without stage mode) + qk_C_2d = cute.group_modes(qk_C_reg, 0, 2) + print(f" qk_C_reg shape: {cute.shape(qk_C_reg)} layout: {qk_C_reg.layout}") + print(f" qk_C_2d shape: {cute.shape(qk_C_2d)} layout: {qk_C_2d.layout}") + try: - tD = thr_copy.partition_D(sP_2d) - print(f" partition_D(sP_2d) shape: {cute.shape(tD)} rank={len(cute.shape(tD))}") + tS = thr_copy.partition_S(qk_C_2d) + print(f" partition_S(qk_C_2d) shape: {cute.shape(tS)} rank={len(cute.shape(tS))}") + print(f" partition_S and partition_D RANKS MATCH: {len(cute.shape(tS)) == len(cute.shape(tD))}") except Exception as e: - print(f" partition_D(sP_2d) FAILED: {e}") - - # Try source partition with coord layout - try: - rP_coord = cute.make_tensor(cute.find(tTMEM_LOADcS.layout, 0), tTMEM_LOADcS.layout) - tS = thr_copy.partition_S(rP_coord) - print(f" partition_S(coord) shape: {cute.shape(tS)} rank={len(cute.shape(tS))}") - except Exception as e: - print(f" partition_S(coord) FAILED: {e}") - - # Try source with QK C-frag layout - try: - rP_qk = cute.make_tensor(cute.find(tStS.layout, 0), tStS.layout) - tS = thr_copy.partition_S(rS=rP_qk) - print(f" partition_S(QK C-frag) shape: {cute.shape(tS)} rank={len(cute.shape(tS))}") - except Exception as e: - print(f" partition_S(QK C-frag) FAILED: {e}") + print(f" partition_S(qk_C_2d) FAILED: {e}") + # Try with a P-sized sub-tensor of the QK C-fragment + # P is only the first p_cols_fp32 columns of S + p_cols_bf16 = pv_mma_tiler[2] # K dimension of PV = columns of P + print(f" p_cols_bf16 (PV tiler[2]): {p_cols_bf16}") + qk_P_shape = (128, p_cols_bf16) # P matrix is (128, p_cols_bf16) + qk_P_layout = cute.composition(qk_C_reg.layout, cute.make_layout(qk_P_shape)) + qk_P = cute.make_tensor(0, qk_P_layout) + qk_P_2d = cute.group_modes(qk_P, 0, 2) + print(f" qk_P_2d shape: {cute.shape(qk_P_2d)} layout: {qk_P_2d.layout}") + try: + tS = thr_copy.partition_S(qk_P_2d) + print(f" partition_S(qk_P_2d) shape: {cute.shape(tS)} rank={len(cute.shape(tS))}") + print(f" RANKS MATCH: {len(cute.shape(tS)) == len(cute.shape(tD))}") + except Exception as e2: + print(f" partition_S(qk_P_2d) FAILED: {e2}") except Exception as e: print(f" make_tiled_copy_C FAILED: {e}") @@ -152,15 +163,21 @@ def main(): try: tiled_copy_pv = cute.make_tiled_copy_C(smem_copy_atom, pv_mma) thr_copy_pv = tiled_copy_pv.get_slice(sfw_idx) - print(f" OK, created tiled_copy_pv") - try: - tD = thr_copy_pv.partition_D(sP_2d) - print(f" partition_D(sP_2d) shape: {cute.shape(tD)} rank={len(cute.shape(tD))}") - except Exception as e: - print(f" partition_D(sP_2d) FAILED: {e}") + tD_pv = thr_copy_pv.partition_D(sP_2d) + print(f" OK, partition_D(sP_2d) shape: {cute.shape(tD_pv)} rank={len(cute.shape(tD_pv))}") except Exception as e: print(f" make_tiled_copy_C(pv_mma) FAILED: {e}") + # KEY QUESTION: what is the relationship between QK C-fragment registers + # and the TMEM load partition? Are they the same threads, same values? + print(f"\n --- QK C-fragment vs TMEM load partition ---") + print(f" tStS shape: {cute.shape(tStS)} layout: {tStS.layout}") + print(f" tTMEM_LOADtS shape: {cute.shape(tTMEM_LOADtS)}") + print(f" tTMEM_LOADcS shape: {cute.shape(tTMEM_LOADcS)} layout: {tTMEM_LOADcS.layout}") + print(f" size(tStS) = {cute.size(tStS)}, size(tTMEM_LOADcS) = {cute.size(tTMEM_LOADcS)}") + # Do the C-fragment and TMEM load have the same per-thread element count? + # If so, the same threads own the same elements in both views. + # Compile and run import cuda.bindings.driver as cuda stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)