D1.3: Layout diagnostic v2 - run inside JIT-compiled kernel
This commit is contained in:
@@ -1,241 +1,182 @@
|
||||
"""
|
||||
D1.3 SMEM-P diagnostic: print ALL relevant shapes and layouts.
|
||||
|
||||
Goal: Understand the exact QK C-fragment partition (source) and
|
||||
PV A-operand SMEM layout (destination) so we can write a proper
|
||||
register→SMEM copy for P values.
|
||||
|
||||
This test ONLY prints shapes. No kernel execution. Pure layout analysis.
|
||||
Uses FmhaKernel._setup to compute layouts the same way the kernel does.
|
||||
Runs layout analysis inside a JIT-compiled function to get MLIR context.
|
||||
"""
|
||||
import torch, math
|
||||
import torch, math, sys
|
||||
import cutlass, cutlass.cute as cute, cutlass.utils as utils
|
||||
from cutlass.cute.nvgpu import tcgen05, cpasync
|
||||
from cutlass import Float32, BFloat16, Int32
|
||||
from cutlass import Float32, BFloat16, Int32, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cutlass.torch as ct
|
||||
|
||||
|
||||
def analyze_layouts(hd, s_k=128):
|
||||
print(f"\n{'='*60}")
|
||||
print(f" HEAD_DIM={hd}, s_k={s_k}")
|
||||
print(f"{'='*60}")
|
||||
def main():
|
||||
for hd in [64, 256]:
|
||||
print(f"\n{'='*60}")
|
||||
print(f" HEAD_DIM={hd}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
m = 128
|
||||
use_smem_p = hd > 64
|
||||
pv_n_tile = min(hd, 256)
|
||||
n_kv_tiles = s_k // 128
|
||||
m = 128; s_k = 128
|
||||
use_smem_p = hd > 64
|
||||
pv_n_tile = min(hd, 256)
|
||||
|
||||
# Create dummy tensors just for layout computation
|
||||
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.randn(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
|
||||
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
v = torch.randn(s_k, pv_n_tile, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
a_major = LayoutEnum.from_tensor(ct.from_dlpack(q)).mma_major_mode()
|
||||
b_major = LayoutEnum.from_tensor(ct.from_dlpack(k)).mma_major_mode()
|
||||
c_major = LayoutEnum.from_tensor(ct.from_dlpack(c)).mma_major_mode()
|
||||
# Use FmhaKernel's actual __call__ path to compute layouts
|
||||
from dsv4.kernels.attention.fmha import FmhaKernel
|
||||
kern = FmhaKernel(head_dim=hd, s_k=s_k)
|
||||
|
||||
# Build QK MMA
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
BFloat16, BFloat16, a_major, b_major, Float32,
|
||||
tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM,
|
||||
)
|
||||
# Trigger _setup by calling compile with a dummy kernel
|
||||
# Actually, we can access the _setup by calling __call__ on the kernel
|
||||
# But that runs the full kernel. Let me manually call _setup by
|
||||
# building the MMAs the same way __call__ does.
|
||||
|
||||
# Build PV MMA (SMEM-P for hd>64, TMEM-P for hd<=64)
|
||||
pv_a_major = a_major if use_smem_p else cute.nvgpu.OperandMajorMode.K
|
||||
pv_source = tcgen05.OperandSource.SMEM if use_smem_p else tcgen05.OperandSource.TMEM
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
BFloat16, BFloat16, pv_a_major, c_major, Float32,
|
||||
tcgen05.CtaGroup.ONE, (128, pv_n_tile), pv_source,
|
||||
)
|
||||
a_major = LayoutEnum.from_tensor(ct.from_dlpack(q)).mma_major_mode()
|
||||
b_major = LayoutEnum.from_tensor(ct.from_dlpack(k)).mma_major_mode()
|
||||
|
||||
# Tiler computation (same as FmhaKernel._setup)
|
||||
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
qk_mma_tiler = (128, 128, qk_ik * 4)
|
||||
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))
|
||||
# Create a minimal analysis kernel that just prints shapes
|
||||
@cute.kernel
|
||||
def _diag_kernel(_q, _k, _v, _c):
|
||||
# Build QK MMA
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
BFloat16, BFloat16, a_major, b_major, Float32,
|
||||
tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM,
|
||||
)
|
||||
|
||||
print(f"\n--- MMA info ---")
|
||||
print(f" QK MMA shape_mnk: {qk_mma.shape_mnk}")
|
||||
print(f" PV MMA shape_mnk: {pv_mma.shape_mnk}")
|
||||
print(f" QK tiler: {qk_mma_tiler}")
|
||||
print(f" PV tiler: {pv_mma_tiler}")
|
||||
print(f" use_smem_p: {use_smem_p}")
|
||||
print(f" PV A source: {pv_source}")
|
||||
print(f" PV A major: {pv_a_major}")
|
||||
pv_a_major = a_major if use_smem_p else cute.nvgpu.OperandMajorMode.K
|
||||
pv_source = tcgen05.OperandSource.SMEM if use_smem_p else tcgen05.OperandSource.TMEM
|
||||
c_major = LayoutEnum.from_tensor(ct.from_dlpack(_c)).mma_major_mode()
|
||||
|
||||
# QK C-fragment (where S lives in TMEM, where softmax reads P from)
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_as)
|
||||
print(f"\n--- QK C-fragment (tStS = S and P in TMEM) ---")
|
||||
print(f" qk_as (partition shape): {qk_as}")
|
||||
print(f" tStS shape: {cute.shape(tStS)}")
|
||||
print(f" tStS size: {cute.size(tStS)}")
|
||||
print(f" tStS layout: {tStS.layout}")
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
BFloat16, BFloat16, pv_a_major, c_major, Float32,
|
||||
tcgen05.CtaGroup.ONE, (128, pv_n_tile), pv_source,
|
||||
)
|
||||
|
||||
# PV C-fragment (where O lives in TMEM)
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_as = pv_thr.partition_shape_C(pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_as)
|
||||
print(f"\n--- PV C-fragment (tOtO = O in TMEM) ---")
|
||||
print(f" pv_as (partition shape): {pv_as}")
|
||||
print(f" tOtO shape: {cute.shape(tOtO)}")
|
||||
print(f" tOtO size: {cute.size(tOtO)}")
|
||||
print(f" tOtO layout: {tOtO.layout}")
|
||||
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
qk_mma_tiler = (128, 128, qk_ik * 4)
|
||||
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))
|
||||
|
||||
# PV A-operand SMEM layout (where P must be written for SMEM-P)
|
||||
p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
|
||||
print(f"\n--- PV A-operand SMEM layout (sP) ---")
|
||||
print(f" p_smem_s outer: {p_smem_s.outer}")
|
||||
print(f" p_smem_s inner (swizzle): {p_smem_s.inner}")
|
||||
outer_layout = cute.make_layout(p_smem_s.outer)
|
||||
print(f" p_smem_s outer shape: {cute.shape(outer_layout)}")
|
||||
print(f" p_smem_s outer size: {cute.size(outer_layout)}")
|
||||
print(f" QK MMA shape_mnk: {qk_mma.shape_mnk}")
|
||||
print(f" PV MMA shape_mnk: {pv_mma.shape_mnk}")
|
||||
print(f" QK tiler: {qk_mma_tiler}")
|
||||
print(f" PV tiler: {pv_mma_tiler}")
|
||||
print(f" use_smem_p: {use_smem_p}")
|
||||
|
||||
# PV A-operand fragments (what MMA reads from TMEM or SMEM)
|
||||
# For TMEM-P: tOrP reads from TMEM
|
||||
# For SMEM-P: tCrP reads from SMEM, tOrP also reads from SMEM
|
||||
p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
|
||||
tP_tmem = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tP_smem = cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer)
|
||||
# QK C-fragment
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_as)
|
||||
print(f" tStS shape: {cute.shape(tStS)} layout: {tStS.layout}")
|
||||
|
||||
tOrP_tmem = pv_thr.make_fragment_A(tP_tmem)
|
||||
tCrP_smem = pv_mma.make_fragment_A(tP_smem)
|
||||
tOrP_smem = pv_thr.make_fragment_A(tP_smem)
|
||||
# PV C-fragment
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_as = pv_thr.partition_shape_C(pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_as)
|
||||
print(f" tOtO shape: {cute.shape(tOtO)} layout: {tOtO.layout}")
|
||||
|
||||
print(f"\n--- PV A-operand fragments ---")
|
||||
print(f" tOrP_tmem (TMEM-P A-read): shape={cute.shape(tOrP_tmem)}")
|
||||
print(f" tCrP_smem (SMEM-P A-load): shape={cute.shape(tCrP_smem)} layout={tCrP_smem.layout}")
|
||||
print(f" tOrP_smem (SMEM-P A-read): shape={cute.shape(tOrP_smem)} layout={tOrP_smem.layout}")
|
||||
print(f" tOrP_tmem size: {cute.size(tOrP_tmem)}")
|
||||
print(f" tCrP_smem size: {cute.size(tCrP_smem)}")
|
||||
# PV A-operand SMEM layout (sP)
|
||||
p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
|
||||
sP_alloc = cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer)
|
||||
print(f" sP shape: {cute.shape(sP_alloc)}")
|
||||
sP_2d = cute.group_modes(sP_alloc, 0, 3)
|
||||
print(f" sP_2d shape: {cute.shape(sP_2d)} rank={len(cute.shape(sP_2d))}")
|
||||
print(f" sP_2d layout: {sP_2d.layout}")
|
||||
|
||||
# Slice tOrP the way the kernel does: [(None,None,None,0)]
|
||||
tOrP_base_tmem = pv_thr.make_fragment_A(tP_tmem)
|
||||
tOrP_tmem_sliced = tOrP_base_tmem[(None,None,None,0)]
|
||||
tOrP_base_smem = pv_thr.make_fragment_A(tP_smem)
|
||||
tOrP_smem_sliced = tOrP_base_smem[(None,None,None,0)]
|
||||
print(f"\n--- PV A-operand after slice [(None,None,None,0)] ---")
|
||||
print(f" tOrP_tmem_sliced: shape={cute.shape(tOrP_tmem_sliced)}")
|
||||
print(f" tOrP_smem_sliced: shape={cute.shape(tOrP_smem_sliced)}")
|
||||
# PV A-operand fragments
|
||||
tP_tmem = cute.make_tensor(tStS.iterator, p_smem_s.outer) # dummy tmem ptr
|
||||
tP_smem = cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer)
|
||||
tOrP_tmem = pv_thr.make_fragment_A(tP_tmem)
|
||||
tCrP_smem = pv_mma.make_fragment_A(tP_smem)
|
||||
tOrP_smem = pv_thr.make_fragment_A(tP_smem)
|
||||
print(f" tOrP_tmem shape: {cute.shape(tOrP_tmem)}")
|
||||
print(f" tCrP_smem shape: {cute.shape(tCrP_smem)} layout: {tCrP_smem.layout}")
|
||||
print(f" tOrP_smem shape: {cute.shape(tOrP_smem)}")
|
||||
|
||||
# Softmax thread partition of S (via TMEM load atoms)
|
||||
sfw_idx = 0
|
||||
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS)
|
||||
cS = cute.make_identity_tensor((128, 128))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
# Slice as kernel does
|
||||
tOrP_tmem_s = tOrP_tmem[(None,None,None,0)]
|
||||
tOrP_smem_s = tOrP_smem[(None,None,None,0)]
|
||||
print(f" tOrP_tmem sliced: {cute.shape(tOrP_tmem_s)}")
|
||||
print(f" tOrP_smem sliced: {cute.shape(tOrP_smem_s)}")
|
||||
|
||||
print(f"\n--- Softmax thread 0: TMEM load partition ---")
|
||||
print(f" tTMEM_LOADtS shape: {cute.shape(tTMEM_LOADtS)}")
|
||||
print(f" tTMEM_LOADcS shape: {cute.shape(tTMEM_LOADcS)}")
|
||||
print(f" tTMEM_LOADcS layout: {tTMEM_LOADcS.layout}")
|
||||
print(f" => Each softmax thread has {cute.size(tTMEM_LOADcS)} S/P coordinates")
|
||||
# Softmax TMEM load partition
|
||||
sfw_idx = 0
|
||||
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS)
|
||||
cS = cute.make_identity_tensor((128, 128))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
print(f" tTMEM_LOADcS (coord) shape: {cute.shape(tTMEM_LOADcS)} layout: {tTMEM_LOADcS.layout}")
|
||||
|
||||
# P register bridge: what P values does each softmax thread have?
|
||||
# After softmax, P is in rP_bf16 which shares layout with rS
|
||||
# rP_bf16 = make_tensor(recast_ptr(rP_words.iterator, dtype=BF16), tTMEM_LOADrS.layout)
|
||||
# where tTMEM_LOADrS = make_rmem_tensor(tTMEM_LOADcS.shape, FP32)
|
||||
p_cols_fp32 = pv_mma_tiler[2] * BFloat16.width // Float32.width
|
||||
tStP_layout = cute.composition(tStS.layout, cute.make_layout((pv_mma_tiler[0], p_cols_fp32)))
|
||||
print(f"\n--- P columns in TMEM ---")
|
||||
print(f" p_cols_fp32 (PV tiler[2] * BF16/FP32): {p_cols_fp32}")
|
||||
print(f" P TMEM sub-layout (composition of tStS): {tStP_layout}")
|
||||
print(f" P TMEM sub-layout shape: {cute.shape(cute.make_layout(tStP_layout))}")
|
||||
# P sub-layout in TMEM
|
||||
p_cols_fp32 = pv_mma_tiler[2] * BFloat16.width // Float32.width
|
||||
print(f" p_cols_fp32: {p_cols_fp32}")
|
||||
|
||||
# Now the critical question: what's the sP shape that the softmax needs to write to?
|
||||
sP_alloc = cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer)
|
||||
sP_2d = cute.group_modes(sP_alloc, 0, 3)
|
||||
print(f"\n--- sP for SMEM-P write target ---")
|
||||
print(f" sP shape (3D): {cute.shape(sP_alloc)}")
|
||||
print(f" sP_2d shape (grouped): {cute.shape(sP_2d)}")
|
||||
print(f" sP_2d rank: {len(cute.shape(sP_2d))}")
|
||||
# make_tiled_copy_C with qk_mma
|
||||
print(f"\n --- make_tiled_copy_C(qk_mma) ---")
|
||||
try:
|
||||
smem_copy_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), BFloat16, num_bits_per_copy=128)
|
||||
tiled_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)
|
||||
thr_copy = tiled_copy.get_slice(sfw_idx)
|
||||
print(f" OK, created tiled_copy")
|
||||
try:
|
||||
tD = thr_copy.partition_D(sP_2d)
|
||||
print(f" partition_D(sP_2d) shape: {cute.shape(tD)} rank={len(cute.shape(tD))}")
|
||||
except Exception as e:
|
||||
print(f" partition_D(sP_2d) FAILED: {e}")
|
||||
|
||||
# Try make_tiled_copy_C with QK MMA (the approach that gave rank mismatch)
|
||||
print(f"\n--- make_tiled_copy_C(smem_copy_atom, qk_mma) ---")
|
||||
try:
|
||||
smem_copy_atom = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
BFloat16,
|
||||
num_bits_per_copy=128,
|
||||
)
|
||||
tiled_copy_qk = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)
|
||||
thr_qk = tiled_copy_qk.get_slice(sfw_idx)
|
||||
tSMEM_QKsP = thr_qk.partition_D(sP_2d)
|
||||
print(f" partition_D(sP_2d) shape: {cute.shape(tSMEM_QKsP)} rank={len(cute.shape(tSMEM_QKsP))}")
|
||||
# Try source partition with coord layout
|
||||
try:
|
||||
rP_coord = cute.make_tensor(cute.find(tTMEM_LOADcS.layout, 0), tTMEM_LOADcS.layout)
|
||||
tS = thr_copy.partition_S(rP_coord)
|
||||
print(f" partition_S(coord) shape: {cute.shape(tS)} rank={len(cute.shape(tS))}")
|
||||
except Exception as e:
|
||||
print(f" partition_S(coord) FAILED: {e}")
|
||||
|
||||
# Try partition_S with various source layouts
|
||||
# 1. Identity tensor (coordinate) layout
|
||||
try:
|
||||
rP_coord = cute.make_tensor(cute.find(tTMEM_LOADcS.layout, 0), tTMEM_LOADcS.layout)
|
||||
tSMEM_QKrP = thr_qk.partition_S(rP_coord)
|
||||
print(f" partition_S(coord layout) shape: {cute.shape(tSMEM_QKrP)} rank={len(cute.shape(tSMEM_QKrP))}")
|
||||
except Exception as e:
|
||||
print(f" partition_S(coord layout) FAILED: {e}")
|
||||
# Try source with QK C-frag layout
|
||||
try:
|
||||
rP_qk = cute.make_tensor(cute.find(tStS.layout, 0), tStS.layout)
|
||||
tS = thr_copy.partition_S(rS=rP_qk)
|
||||
print(f" partition_S(QK C-frag) shape: {cute.shape(tS)} rank={len(cute.shape(tS))}")
|
||||
except Exception as e:
|
||||
print(f" partition_S(QK C-frag) FAILED: {e}")
|
||||
except Exception as e:
|
||||
print(f" make_tiled_copy_C FAILED: {e}")
|
||||
|
||||
# 2. QK C-fragment layout (tStS)
|
||||
try:
|
||||
rP_qk = cute.make_tensor(cute.find(tStS.layout, 0), tStS.layout)
|
||||
tSMEM_QKrP = thr_qk.partition_S(rP_qk)
|
||||
print(f" partition_S(QK C-frag layout) shape: {cute.shape(tSMEM_QKrP)} rank={len(cute.shape(tSMEM_QKrP))}")
|
||||
except Exception as e:
|
||||
print(f" partition_S(QK C-frag layout) FAILED: {e}")
|
||||
# make_tiled_copy_C with pv_mma
|
||||
print(f"\n --- make_tiled_copy_C(pv_mma) ---")
|
||||
try:
|
||||
tiled_copy_pv = cute.make_tiled_copy_C(smem_copy_atom, pv_mma)
|
||||
thr_copy_pv = tiled_copy_pv.get_slice(sfw_idx)
|
||||
print(f" OK, created tiled_copy_pv")
|
||||
try:
|
||||
tD = thr_copy_pv.partition_D(sP_2d)
|
||||
print(f" partition_D(sP_2d) shape: {cute.shape(tD)} rank={len(cute.shape(tD))}")
|
||||
except Exception as e:
|
||||
print(f" partition_D(sP_2d) FAILED: {e}")
|
||||
except Exception as e:
|
||||
print(f" make_tiled_copy_C(pv_mma) FAILED: {e}")
|
||||
|
||||
# 3. The rmem layout that softmax actually has P in
|
||||
rP_rmem_layout = cute.make_layout(tTMEM_LOADcS.shape)
|
||||
try:
|
||||
rP_rmem = cute.make_tensor(cute.find(rP_rmem_layout, 0), rP_rmem_layout)
|
||||
tSMEM_QKrP = thr_qk.partition_S(rP_rmem)
|
||||
print(f" partition_S(rmem layout) shape: {cute.shape(tSMEM_QKrP)} rank={len(cute.shape(tSMEM_QKrP))}")
|
||||
except Exception as e:
|
||||
print(f" partition_S(rmem layout) FAILED: {e}")
|
||||
# Compile and run
|
||||
import cuda.bindings.driver as cuda
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
except Exception as e:
|
||||
print(f" make_tiled_copy_C(qk_mma) FAILED: {e}")
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
mV = ct.from_dlpack(v.unsqueeze(-1)).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v.unsqueeze(-1)))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
|
||||
# Try make_tiled_copy_C with PV MMA (different approach)
|
||||
print(f"\n--- make_tiled_copy_C(smem_copy_atom, pv_mma) ---")
|
||||
try:
|
||||
smem_copy_atom = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
BFloat16,
|
||||
num_bits_per_copy=128,
|
||||
)
|
||||
tiled_copy_pv = cute.make_tiled_copy_C(smem_copy_atom, pv_mma)
|
||||
thr_pv = tiled_copy_pv.get_slice(sfw_idx)
|
||||
tSMEM_PVsP = thr_pv.partition_D(sP_2d)
|
||||
print(f" partition_D(sP_2d) shape: {cute.shape(tSMEM_PVsP)} rank={len(cute.shape(tSMEM_PVsP))}")
|
||||
except Exception as e:
|
||||
print(f" make_tiled_copy_C(pv_mma) FAILED: {e}")
|
||||
|
||||
# Try make_tiled_copy (not _C) with explicit M/N layout
|
||||
print(f"\n--- make_tiled_copy (manual M/N) ---")
|
||||
try:
|
||||
# The P matrix is (128, n_kv) = (M=128, N_P=128)
|
||||
# Use (128, 128) tile for the copy
|
||||
MN = cute.make_layout((128, 128))
|
||||
smem_copy_atom2 = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
BFloat16,
|
||||
num_bits_per_copy=128,
|
||||
)
|
||||
# Need to use make_tiled_copy with proper thread arrangement
|
||||
# This is for the softmax warps (128 threads = 4 warps)
|
||||
thr_layout = cute.make_layout((4, 32)) # 4 warps × 32 lanes
|
||||
val_layout = cute.make_layout((1, 1)) # each thread copies 1 element
|
||||
tiled_manual = cute.make_tiled_copy(smem_copy_atom2, thr_layout, val_layout)
|
||||
thr_manual = tiled_manual.get_slice(sfw_idx)
|
||||
tMANUALsP = thr_manual.partition_D(sP_2d)
|
||||
print(f" partition_D(sP_2d) shape: {cute.shape(tMANUALsP)} rank={len(cute.shape(tMANUALsP))}")
|
||||
except Exception as e:
|
||||
print(f" make_tiled_copy (manual) FAILED: {e}")
|
||||
print(' Compiling diagnostic...', flush=True)
|
||||
compiled = cute.compile(_diag_kernel, mQ, mK, mV, mC)
|
||||
compiled(mQ, mK, mV, mC)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
analyze_layouts(64) # TMEM-P (works)
|
||||
analyze_layouts(256) # SMEM-P (broken)
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user