#!/usr/bin/env python3
"""
Print UMMA SMEM descriptors and layout for FMHA decode.

This script uses CuTeDSL to construct the exact SMEM layout and UMMA
descriptors that the FMHA kernel uses. We then hardcode these values in our
raw CUDA kernel.

The process exits with status 0 on success and 1 if layout construction or
printing raised an exception, so callers can tell a failed dump from a good
one.
"""

import sys
import traceback

# The CuTeDSL checkout has to be on sys.path before the imports below resolve.
sys.path.insert(0, '/root/cutlass/python/CuTeDSL')

import cutlass
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.utils.blackwell_helpers import OperandMajorMode
# NOTE(review): CuTeDSL code usually spells this `import cutlass.cute as cute`;
# confirm that a top-level `cute` module is importable in this environment.
import cute

# FMHA decode configuration
HEAD_DIM = 64
M = 128  # head-packed rows
SK_TILE = 128  # KV tile size

# BF16 dtype
# NOTE(review): the name says bf16 but this binds cutlass.float16; the actual
# kernel is expected to use bf16. Confirm the attribute name against the
# installed cutlass package.
bf16 = cutlass.float16  # Will use bf16 in the actual kernel

# Bytes per element, used only for the "size (bytes)" report lines.
# Assumes a 16-bit element type (true for both fp16 and bf16).
ELEMENT_BYTES = 2

# How much of each layout's top-left corner is dumped when printing offsets.
OFFSET_ROWS = 4
OFFSET_COLS = 8


def _print_layout_summary(name: str, layout) -> None:
    """Print layout, shape, stride, element count and byte size of *layout*."""
    num_elements = cute.size(layout)
    print(f"{name} SMEM layout: {layout}")
    print(f"{name} SMEM shape: {cute.shape(layout)}")
    print(f"{name} SMEM stride: {cute.stride(layout)}")
    print(f"{name} SMEM size (elements): {num_elements}")
    print(f"{name} SMEM size (bytes): {num_elements * ELEMENT_BYTES}")


def _print_offsets(name: str, layout, rows: int = OFFSET_ROWS,
                   cols: int = OFFSET_COLS) -> None:
    """Print the offset *layout* maps each (row, col) to, one row per line."""
    print(f"\n{name} swizzle offsets (row, col) -> offset:")
    for row in range(rows):
        for col in range(cols):
            offset = layout(row, col)
            print(f" ({row},{col}) -> {offset}", end="")
        print()  # terminate the current row of offsets


def main() -> int:
    """Build the Q and K SMEM layouts and print them.

    Returns:
        0 when everything was printed, 1 when any step raised. The error and
        its traceback are printed before returning.
    """
    try:
        # Construct SMEM layouts using the same code the FMHA kernel uses
        #   MN-major A (Q): (128, 64) BF16
        #   K-major  B (K): (128, 64) BF16
        # NOTE(review): confirm these keyword names against the installed
        # blackwell_helpers version.

        # MN-major layout for Q
        q_layout = sm100_utils.make_smem_layout_a(
            major_mode=OperandMajorMode.MN_MAJOR,
            smem_tile_shape=cute.make_shape(M, HEAD_DIM),
            element_type=bf16,
            stage=1,
        )
        _print_layout_summary("Q", q_layout)

        # K-major layout for K; its row extent is the KV tile size.
        k_layout = sm100_utils.make_smem_layout_b(
            major_mode=OperandMajorMode.K_MAJOR,
            smem_tile_shape=cute.make_shape(SK_TILE, HEAD_DIM),
            element_type=bf16,
            stage=1,
        )
        print()
        _print_layout_summary("K", k_layout)

        # Print a few element offsets to understand the swizzle pattern
        _print_offsets("Q", q_layout)
        _print_offsets("K", k_layout)

        # Construct UMMA descriptors
        # We need to use cute.make_umma_desc which takes a tensor, not raw
        # values. But we can construct a tensor with the layout and extract
        # the descriptor.
        print("\n=== UMMA Descriptor Construction ===")
        print("(Need to construct CuTe tensor with the SMEM layout and call make_umma_desc)")
    except Exception as e:
        # Top-level boundary: report everything, then signal failure through
        # the exit status instead of exiting 0.
        print(f"Error: {e}")
        traceback.print_exc()
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())