96 lines
4.3 KiB
Python
96 lines
4.3 KiB
Python
|
|
"""Compare C-fragment composition layout vs A-fragment layout for PV P operand."""
|
||
|
|
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils
|
||
|
|
from cutlass.cute.nvgpu import tcgen05
|
||
|
|
from cutlass import Float32, BFloat16
|
||
|
|
from cutlass.utils import LayoutEnum
|
||
|
|
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||
|
|
import cuda.bindings.driver as cuda
|
||
|
|
import cutlass.torch as ct
|
||
|
|
|
||
|
|
|
||
|
|
class LayoutCompareKernel:
    """Print two ways of viewing the P operand of the PV MMA over the S accumulator.

    The callable builds a QK tiled MMA (operand A sourced from SMEM) and a PV
    tiled MMA (operand A sourced from TMEM), then prints:

    * the A-fragment layout obtained through ``make_fragment_A``, and
    * the layout obtained by composing the QK C-fragment layout with a
      narrower ``(128, tilePlikeFP32)`` layout.

    The visible body only constructs tensors/layouts and prints them; it does
    not issue any MMA or copy.
    """

    def __init__(self):
        # Both accumulators are FP32; every operand/output type is BF16.
        self.acc_dtype = self.qk_acc_dtype = Float32
        self.q_dtype = self.o_dtype = self.c_dtype = BFloat16
        # (M, N) extent of the QK tile.
        self.mma_tiler_mn = (128, 128)
        self.cta_group = tcgen05.CtaGroup.ONE
        self.threads_per_cta = 64  # minimal

    @cute.jit
    def __call__(self, q, k, v, c, stream):
        """Build the QK/PV MMAs from the operand major modes and print layouts.

        ``q``, ``k`` and ``v`` are only inspected for their MMA major mode;
        ``c`` and ``stream`` are accepted but not referenced in this body.
        """
        q_major = LayoutEnum.from_tensor(q).mma_major_mode()
        k_major = LayoutEnum.from_tensor(k).mma_major_mode()
        v_major = LayoutEnum.from_tensor(v).mma_major_mode()

        # S = Q @ K^T: A operand read from shared memory.
        qk_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype,
            self.q_dtype,
            q_major,
            k_major,
            self.qk_acc_dtype,
            self.cta_group,
            self.mma_tiler_mn,
            tcgen05.OperandSource.SMEM,
        )
        # O = P @ V: A operand (P) is K-major and read from tensor memory.
        pv_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype,
            self.q_dtype,
            cute.nvgpu.OperandMajorMode.K,
            v_major,
            self.qk_acc_dtype,
            self.cta_group,
            self.mma_tiler_mn,
            tcgen05.OperandSource.TMEM,
        )

        # QK tile covers four K-instructions; the PV tile swaps N and K.
        qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
        qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
        pv_mma_tiler = (qk_mma_tiler[0], qk_mma_tiler[2], qk_mma_tiler[1])

        qk_slice = qk_mma.get_slice(0)
        pv_slice = pv_mma.get_slice(0)

        # Accumulator fragment of S.
        s_acc_shape = qk_slice.partition_shape_C(qk_mma_tiler[:2])
        tStS = qk_slice.make_fragment_C(s_acc_shape)
        # Same iterator/layout as tStS; kept although nothing below reads it.
        tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)

        # Accumulator fragment of O.
        o_acc_shape = pv_slice.partition_shape_C(pv_mma_tiler[:2])
        tOtO = pv_slice.make_fragment_C(o_acc_shape)

        # P A-fragment: a single-stage A layout placed on the S iterator,
        # then stage 0 selected from the resulting fragment.
        p_layout_staged = utils.sm100.make_smem_layout_a(
            pv_mma, pv_mma_tiler, self.q_dtype, 1
        )
        tP = cute.make_tensor(tStS.iterator, p_layout_staged.outer)
        tOrP_staged = pv_slice.make_fragment_A(tP)
        tOrP = tOrP_staged[None, None, None, 0]

        # Offset of P relative to S, counted in FP32 columns.
        p_offset_fp32_cols = 32

        # C-fragment composition layout: the S layout restricted to the
        # number of FP32-wide columns that the o_dtype-wide P tile occupies.
        tilePlikeFP32 = qk_mma_tiler[1] // Float32.width * self.o_dtype.width
        tStS_P_layout = cute.composition(
            tStS.layout, cute.make_layout((128, tilePlikeFP32))
        )
        tStS_P = cute.make_tensor(
            tStS.iterator + p_offset_fp32_cols, tStS_P_layout
        )  # offset 32 FP32 columns

        # With scaled offset for A-fragment: the same offset expressed in
        # q_dtype-sized elements instead of FP32 columns.
        p_offset_in_a_elements = (
            self.qk_acc_dtype.width // self.q_dtype.width * p_offset_fp32_cols
        )  # = 64
        tOrP0 = cute.make_tensor(
            tOrP.iterator + p_offset_in_a_elements, tOrP.layout
        )

        # Print layouts
        cute.printf("tStS layout: {}", tStS.layout)
        cute.printf("tOrP layout: {}", tOrP.layout)
        cute.printf("tStS_P layout: {}", tStS_P_layout)
        cute.printf("tOrP0 layout: {}", tOrP0.layout)
        cute.printf("tOrP shape: {}", tOrP.shape)
        cute.printf("tStS_P shape: {}", tStS_P.shape)
        cute.printf("tOtO layout: {}", tOtO.layout)
        cute.printf("pv_mma_tiler: {}", pv_mma_tiler)
        cute.printf("qk_mma_tiler: {}", qk_mma_tiler)
|
||
|
|
|
||
|
|
|
||
|
|
def test():
    """Compile and run ``LayoutCompareKernel`` once on random BF16 inputs.

    Allocates ``q`` (m, head_dim, 1), ``k`` (n, head_dim, 1), a strided ``v``
    (head_dim, n, 1) and a zeroed ``c`` (m, head_dim, 1) on the CUDA device,
    wraps them as CuTe tensors with a dynamic layout, then compiles and
    invokes the kernel on the current torch CUDA stream.
    """
    m, n, head_dim = 128, 128, 64
    bf16_on_gpu = {"dtype": torch.bfloat16, "device": 'cuda'}

    q = torch.randn(m, head_dim, 1, **bf16_on_gpu)
    k = torch.randn(n, head_dim, 1, **bf16_on_gpu)
    # Reinterpret the (head_dim, n) buffer with strides (1, head_dim), so the
    # first mode is the unit-stride one, then append a trailing size-1 mode.
    v_base = torch.randn(head_dim, n, **bf16_on_gpu)
    v = v_base.as_strided((head_dim, n), (1, head_dim)).unsqueeze(-1)
    c = torch.zeros(m, head_dim, 1, **bf16_on_gpu)

    def to_dynamic_cute(tensor):
        # Wrap a torch tensor via DLPack and mark its layout dynamic around
        # the leading dimension reported for that tensor.
        return ct.from_dlpack(tensor).mark_layout_dynamic(
            leading_dim=ct.get_leading_dim(tensor)
        )

    mQ, mK, mV, mC = [to_dynamic_cute(t) for t in (q, k, v, c)]

    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = LayoutCompareKernel()

    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)

    print('Running...', flush=True)
    compiled(mQ, mK, mV, mC, stream)
    # Wait for outstanding device work before returning.
    torch.cuda.synchronize()
|
||
|
|
|
||
|
|
|
||
|
|
# Script entry point: run the layout comparison only when executed directly.
if __name__ == "__main__":
    test()
|