54 lines
1.6 KiB
Python
54 lines
1.6 KiB
Python
"""
|
|
Minimal CuTeDSL script to print UMMA descriptors for FMHA Q and K.
|
|
Run on B200 with: python3 print_umma_desc_cute.py
|
|
"""
|
|
import sys
|
|
sys.path.insert(0, '/root/cutlass/python/CuTeDSL')
|
|
|
|
import torch
|
|
import cutlass
|
|
import cute
|
|
|
|
# Tile configuration matching our FMHA decode.
HEAD_DIM = 64  # second extent of both the Q and the K SMEM tile
M = 128  # first extent of the Q tile
N = 128  # KV tile: first extent of the K tile
|
|
|
|
# JIT-compiled helper that builds the Q/K shared-memory layouts and prints them.
@cute.jit
def print_umma_descriptors():
    """Build the SMEM layouts for the FMHA Q and K operands and print them.

    Despite the name, only the shape and size of each layout are printed;
    no descriptor is constructed here. The function takes no arguments and
    returns nothing; the tile extents come from the module-level constants
    ``M``, ``N`` and ``HEAD_DIM``.

    NOTE(review): the keyword arguments passed to ``make_smem_layout_a`` /
    ``make_smem_layout_b`` and the ``OperandMajorMode`` member names are kept
    as written in the original script -- confirm them against the installed
    CuTeDSL version.
    """
    # Function-local import, kept from the original script.
    from cutlass.utils.blackwell_helpers import (
        OperandMajorMode,
        make_smem_layout_a,
        make_smem_layout_b,
    )

    # Q: MN-major, (M, HEAD_DIM) BF16
    q_layout = make_smem_layout_a(
        major_mode=OperandMajorMode.MN_MAJOR,
        smem_tile_shape=cute.make_shape(M, HEAD_DIM),
        # The operand comments all say BF16, so request the bfloat16 type
        # rather than a half-precision float16 one.
        element_type=cutlass.BFloat16,
        stage=1,
    )

    # K: K-major, (N, HEAD_DIM) BF16
    k_layout = make_smem_layout_b(
        major_mode=OperandMajorMode.K_MAJOR,
        smem_tile_shape=cute.make_shape(N, HEAD_DIM),
        element_type=cutlass.BFloat16,
        stage=1,
    )

    # Report the shape and the size of each layout.
    cute.printf("Q layout shape: {}", cute.shape(q_layout))
    cute.printf("Q layout size: {}", cute.size(q_layout))
    cute.printf("K layout shape: {}", cute.shape(k_layout))
    cute.printf("K layout size: {}", cute.size(k_layout))
|
|
|
|
if __name__ == "__main__":
    import traceback

    # Just call the JIT-decorated function to trigger compilation and see the prints.
    print("Compiling...")
    try:
        print_umma_descriptors()
    except Exception as e:
        # Top-level boundary of a diagnostic script: report the failure in
        # full, then exit non-zero so a caller (shell, CI) can detect it
        # instead of seeing a successful exit status.
        print(f"Error: {e}")
        traceback.print_exc()
        sys.exit(1)
|