From d29d6b575f48ff25ebb72b3e8aa954eda442e7a8 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 08:20:56 +0000 Subject: [PATCH] add UMMA descriptor diagnostic script --- scripts/print_umma_desc.py | 79 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 scripts/print_umma_desc.py diff --git a/scripts/print_umma_desc.py b/scripts/print_umma_desc.py new file mode 100644 index 00000000..44bcb55c --- /dev/null +++ b/scripts/print_umma_desc.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +""" +Print UMMA SMEM descriptors and layout for FMHA decode. + +This script uses CuTeDSL to construct the exact SMEM layout and UMMA +descriptors that the FMHA kernel uses. We then hardcode these values +in our raw CUDA kernel. +""" +import sys +sys.path.insert(0, '/root/cutlass/python/CuTeDSL') + +import cutlass +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.utils.blackwell_helpers import OperandMajorMode +import cute + +# FMHA decode configuration +HEAD_DIM = 64 +M = 128 # head-packed rows +SK_TILE = 128 # KV tile size + +# BF16 dtype +bf16 = cutlass.float16 # Will use bf16 in the actual kernel + +# Construct SMEM layouts using the same code the FMHA kernel uses +# MN-major A (Q): (128, 64) BF16 +# K-major B (K): (128, 64) BF16 + +try: + # MN-major layout for Q + q_layout = sm100_utils.make_smem_layout_a( + major_mode=OperandMajorMode.MN_MAJOR, + smem_tile_shape=cute.make_shape(M, HEAD_DIM), + element_type=bf16, + stage=1, + ) + print(f"Q SMEM layout: {q_layout}") + print(f"Q SMEM shape: {cute.shape(q_layout)}") + print(f"Q SMEM stride: {cute.stride(q_layout)}") + print(f"Q SMEM size (elements): {cute.size(q_layout)}") + print(f"Q SMEM size (bytes): {cute.size(q_layout) * 2}") + + # K-major layout for K + k_layout = sm100_utils.make_smem_layout_b( + major_mode=OperandMajorMode.K_MAJOR, + smem_tile_shape=cute.make_shape(M, HEAD_DIM), + element_type=bf16, + stage=1, + ) + print(f"\nK SMEM layout: {k_layout}") + print(f"K SMEM shape: {cute.shape(k_layout)}") + print(f"K SMEM stride: {cute.stride(k_layout)}") + print(f"K SMEM size (elements): {cute.size(k_layout)}") + + # Print a few element offsets to understand the swizzle pattern + print("\nQ swizzle offsets (row, col) -> offset:") + for row in range(4): + for col in range(8): + offset = q_layout(row, col) + print(f" ({row},{col}) -> {offset}", end="") + print() + + print("\nK swizzle offsets (row, col) -> offset:") + for row in range(4): + for col in range(8): + offset = k_layout(row, col) + print(f" ({row},{col}) -> {offset}", end="") + print() + + # Construct UMMA descriptors + # We need to use cute.make_umma_desc which takes a tensor, not raw values + # But we can construct a tensor with the layout and extract the descriptor + print("\n=== UMMA Descriptor Construction ===") + print("(Need to construct CuTe tensor with the SMEM layout and call make_umma_desc)") + +except Exception as e: + print(f"Error: {e}") + import traceback + traceback.print_exc()