73 lines
3.6 KiB
Python
73 lines
3.6 KiB
Python
"""Print V SMEM layouts for (128,64) and (128,128) PV. Must run inside JIT."""
|
|
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
|
from cutlass.cute.nvgpu import cpasync, tcgen05, OperandMajorMode
|
|
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
|
from cutlass.utils import LayoutEnum
|
|
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
|
import cuda.bindings.driver as cuda
|
|
import cutlass.torch as ct
|
|
|
|
class SmemLayoutKernel:
    """JIT-traced probe that prints B-operand SMEM layouts.

    Calling an instance builds one QK tiled MMA and two PV tiled MMAs (with
    (128, 64) and (128, 128) tiles) and prints the ``outer`` / ``inner``
    attributes of the layouts returned by ``make_smem_layout_b``, followed
    by the ``shape_mnk`` of both PV MMAs. No kernel is launched from here.
    """

    def __init__(self):
        """Populate the fixed dtype, CTA and warp-role attributes."""
        # Element types.
        self.q_dtype = BFloat16
        self.o_dtype = BFloat16
        self.c_dtype = BFloat16
        self.qk_acc_dtype = Float32

        # CTA / cluster configuration.
        self.use_2cta_instrs = False
        self.cluster_shape_mn = (1, 1)
        self.cta_group = tcgen05.CtaGroup.ONE

        # Warp roles: ids 0-3 epilogue, 4 MMA, 5 TMA.
        self.epilogue_warp_id = (0, 1, 2, 3)
        self.mma_warp_id = 4
        self.tma_warp_id = 5
        # 192 == 6 * 32, i.e. the six warp ids above (assuming 32-thread warps).
        self.threads_per_cta = 192
        self.num_c_stage = 2

    @cute.jit
    def __call__(self, q, k, v, c, stream):
        """Print the SMEM layouts of the B operand for the QK and PV MMAs.

        Args:
            q: Tensor supplying the A-operand major mode of the QK MMA.
            k: Tensor supplying the B-operand major mode of the QK MMA.
            v: Tensor supplying the B-operand major mode of both PV MMAs.
            c: Output tensor; only its element type and layout enum are read.
            stream: Accepted for interface compatibility; not used here.
        """
        self.q_dtype = q.element_type
        self.o_dtype = c.element_type

        major_q = LayoutEnum.from_tensor(q).mma_major_mode()
        major_k = LayoutEnum.from_tensor(k).mma_major_mode()
        major_v = LayoutEnum.from_tensor(v).mma_major_mode()
        # Computed like the other layout queries but not consumed below.
        layout_c = LayoutEnum.from_tensor(c)

        # Shorthands for the factories and constants shared by all three MMAs.
        build_mma = utils.sm100.make_trivial_tiled_mma
        build_b_layout = utils.sm100.make_smem_layout_b
        one_cta = tcgen05.CtaGroup.ONE
        num_stages = 1

        # ---- QK: both operands sourced from SMEM, 128x128 tile ----
        tiled_mma_qk = build_mma(
            BFloat16,
            BFloat16,
            major_q,
            major_k,
            Float32,
            one_cta,
            (128, 128),
            tcgen05.OperandSource.SMEM,
        )
        inst_k_qk = cute.size(tiled_mma_qk.shape_mnk, mode=[2])
        # K extent of the tiler is four MMA instructions deep.
        tiler_qk = (128, 128, inst_k_qk * 4)
        smem_b_qk = build_b_layout(tiled_mma_qk, tiler_qk, BFloat16, num_stages)
        print(f"QK B SMEM: outer={smem_b_qk.outer}, inner={smem_b_qk.inner}")

        # ---- PV with a (128, 64) tile: A operand sourced from TMEM ----
        tiled_mma_pv64 = build_mma(
            BFloat16,
            BFloat16,
            OperandMajorMode.K,
            major_v,
            Float32,
            one_cta,
            (128, 64),
            tcgen05.OperandSource.TMEM,
        )
        tiler_pv64 = (128, 64, 128)
        smem_v64 = build_b_layout(tiled_mma_pv64, tiler_pv64, BFloat16, num_stages)
        print(f"PV(128,64) V SMEM: outer={smem_v64.outer}, inner={smem_v64.inner}")

        # ---- PV with a (128, 128) tile: A operand sourced from TMEM ----
        tiled_mma_pv128 = build_mma(
            BFloat16,
            BFloat16,
            OperandMajorMode.K,
            major_v,
            Float32,
            one_cta,
            (128, 128),
            tcgen05.OperandSource.TMEM,
        )
        tiler_pv128 = (128, 128, 128)
        smem_v128 = build_b_layout(tiled_mma_pv128, tiler_pv128, BFloat16, num_stages)
        print(f"PV(128,128) V SMEM: outer={smem_v128.outer}, inner={smem_v128.inner}")

        # Finally report the MNK shapes of both PV MMAs.
        print(f"PV(128,64) MMA shape_mnk={tiled_mma_pv64.shape_mnk}")
        print(f"PV(128,128) MMA shape_mnk={tiled_mma_pv128.shape_mnk}")
|
|
|
|
# Driver: create small bf16 inputs on the GPU, wrap them as CuTe tensors and
# compile SmemLayoutKernel; the layout prints are issued from its __call__.
torch.manual_seed(42)

m, n, head_dim = 128, 128, 64
_BF16_ON_GPU = {"dtype": torch.bfloat16, "device": "cuda"}


def _to_dynamic_cute(tensor):
    """Wrap ``tensor`` through DLPack and mark its layout as dynamic."""
    wrapped = ct.from_dlpack(tensor)
    return wrapped.mark_layout_dynamic(leading_dim=ct.get_leading_dim(tensor))


# Random Q and K; the trailing dimension of size 1 is part of every input.
q = torch.randn((m, head_dim, 1), **_BF16_ON_GPU)
k = torch.randn((n, head_dim, 1), **_BF16_ON_GPU)

# V is all zeros except a single 1.0 at [0, 0], then re-viewed over the same
# storage with strides (1, head_dim) so the first dimension is unit-stride.
v_data = torch.zeros((head_dim, n), **_BF16_ON_GPU)
v_data[0, 0] = 1.0
v = v_data.as_strided((head_dim, n), (1, head_dim)).unsqueeze(-1)

c = torch.zeros((m, head_dim, 1), **_BF16_ON_GPU)

mQ, mK, mV, mC = map(_to_dynamic_cute, (q, k, v, c))

# Hand the current torch CUDA stream to the CUDA driver bindings.
_raw_stream = torch.cuda.current_stream().cuda_stream
stream = cuda.CUstream(_raw_stream)

kernel = SmemLayoutKernel()
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
|