stuff and stuff
This commit is contained in:
@@ -39,10 +39,12 @@ class PvHeadDimKernel:
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
# Diagnostic: print V SMEM size
|
||||
v_s = cute.slice_(self.v_smem_s, (None,None,None,0))
|
||||
v_sz = cute.size_in_bytes(self.q_dtype, v_s)
|
||||
print(f"[DIAG] pv_mma_tiler={self.pv_mma_tiler} V SMEM per stage: {v_sz} bytes ({v_sz//2} BF16)")
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
tCrV = pv_mma.make_fragment_B(v_s)
|
||||
print(f"[DIAG] tCrV shape={tCrV.shape} size={cute.size(tCrV)} k_phases={cute.size(tCrV,mode=[2])}")
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
|
||||
|
||||
Reference in New Issue
Block a user