Files
nvfp4-megamoe-kernel/scripts/print_umma_desc_cute.py

54 lines
1.6 KiB
Python

"""
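Minimal CuTeDSL script for inspecting the UMMA operand layouts of FMHA Q and K.
It currently prints only the shape and size of each SMEM layout, not the
descriptor itself.
Run on a B200 with: python3 print_umma_desc_cute.py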
"""
import sys
sys.path.insert(0, '/root/cutlass/python/CuTeDSL')
import torch
import cutlass
import cute
# Tile configuration, chosen to match our FMHA decode setup.
M = 128  # first extent of the Q tile shape
N = 128  # first extent of the K tile shape (KV tile)
HEAD_DIM = 64  # second extent of both the Q and K tile shapes
# JIT-compiled helper that prints the shape and size of the Q and K SMEM layouts.
@cute.jit
def print_umma_descriptors():
    """Build the Q and K shared-memory layouts and print their shape and size.

    Takes no arguments and returns nothing; the only output is the four
    ``cute.printf`` lines below. Tile extents come from the module-level
    ``M``, ``N`` and ``HEAD_DIM`` constants.

    NOTE(review): despite the function name, only layout shape/size is
    printed here, not a descriptor value.
    NOTE(review): ``element_type`` is ``cutlass.float16``; if the FMHA
    operands are BF16, confirm this is the intended type.
    """
    # SM100 layout helpers. Intended to mirror how the FMHA kernel builds its
    # SMEM layouts -- TODO confirm against the kernel source.
    from cutlass.utils.blackwell_helpers import (
        make_smem_layout_a,
        make_smem_layout_b,
        OperandMajorMode,
    )

    # Tile shapes: Q is (M, HEAD_DIM), K is (N, HEAD_DIM).
    q_tile_shape = cute.make_shape(M, HEAD_DIM)
    k_tile_shape = cute.make_shape(N, HEAD_DIM)

    # Q operand: MN-major, single stage.
    q_smem_layout = make_smem_layout_a(
        major_mode=OperandMajorMode.MN_MAJOR,
        smem_tile_shape=q_tile_shape,
        element_type=cutlass.float16,
        stage=1,
    )

    # K operand: K-major, single stage.
    k_smem_layout = make_smem_layout_b(
        major_mode=OperandMajorMode.K_MAJOR,
        smem_tile_shape=k_tile_shape,
        element_type=cutlass.float16,
        stage=1,
    )

    # Report Q first, then K: shape followed by size for each layout.
    cute.printf("Q layout shape: {}", cute.shape(q_smem_layout))
    cute.printf("Q layout size: {}", cute.size(q_smem_layout))
    cute.printf("K layout shape: {}", cute.shape(k_smem_layout))
    cute.printf("K layout size: {}", cute.size(k_smem_layout))
if __name__ == "__main__":
    # Script entry point: calling the jitted function triggers JIT
    # compilation, which is what produces the layout prints.
    print("Compiling...")
    try:
        print_umma_descriptors()
    except Exception as e:
        # Top-level boundary of a debug script: show the message and the full
        # traceback so the failure can be diagnosed from the console output.
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        # Exit non-zero so shells and CI do not mistake a failed compile for
        # success (previously the script fell through and exited with 0).
        raise SystemExit(1) from e