add UMMA descriptor diagnostic script

This commit is contained in:
2026-05-28 08:20:56 +00:00
parent ab84ad0f86
commit d29d6b575f

View File

@@ -0,0 +1,79 @@
#!/usr/bin/env python3
"""
Print UMMA SMEM descriptors and layout for FMHA decode.
This script uses CuTeDSL to construct the exact SMEM layout and UMMA
descriptors that the FMHA kernel uses. We then hardcode these values
in our raw CUDA kernel.
"""
import sys
sys.path.insert(0, '/root/cutlass/python/CuTeDSL')
import cutlass
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.utils.blackwell_helpers import OperandMajorMode
import cute
# FMHA decode configuration
HEAD_DIM = 64
M = 128 # head-packed rows
SK_TILE = 128 # KV tile size
# BF16 dtype
bf16 = cutlass.float16 # Will use bf16 in the actual kernel
# Construct SMEM layouts using the same code the FMHA kernel uses
# MN-major A (Q): (128, 64) BF16
# K-major B (K): (128, 64) BF16
try:
# MN-major layout for Q
q_layout = sm100_utils.make_smem_layout_a(
major_mode=OperandMajorMode.MN_MAJOR,
smem_tile_shape=cute.make_shape(M, HEAD_DIM),
element_type=bf16,
stage=1,
)
print(f"Q SMEM layout: {q_layout}")
print(f"Q SMEM shape: {cute.shape(q_layout)}")
print(f"Q SMEM stride: {cute.stride(q_layout)}")
print(f"Q SMEM size (elements): {cute.size(q_layout)}")
print(f"Q SMEM size (bytes): {cute.size(q_layout) * 2}")
# K-major layout for K
k_layout = sm100_utils.make_smem_layout_b(
major_mode=OperandMajorMode.K_MAJOR,
smem_tile_shape=cute.make_shape(M, HEAD_DIM),
element_type=bf16,
stage=1,
)
print(f"\nK SMEM layout: {k_layout}")
print(f"K SMEM shape: {cute.shape(k_layout)}")
print(f"K SMEM stride: {cute.stride(k_layout)}")
print(f"K SMEM size (elements): {cute.size(k_layout)}")
# Print a few element offsets to understand the swizzle pattern
print("\nQ swizzle offsets (row, col) -> offset:")
for row in range(4):
for col in range(8):
offset = q_layout(row, col)
print(f" ({row},{col}) -> {offset}", end="")
print()
print("\nK swizzle offsets (row, col) -> offset:")
for row in range(4):
for col in range(8):
offset = k_layout(row, col)
print(f" ({row},{col}) -> {offset}", end="")
print()
# Construct UMMA descriptors
# We need to use cute.make_umma_desc which takes a tensor, not raw values
# But we can construct a tensor with the layout and extract the descriptor
print("\n=== UMMA Descriptor Construction ===")
print("(Need to construct CuTe tensor with the SMEM layout and call make_umma_desc)")
except Exception as e:
print(f"Error: {e}")
import traceback
traceback.print_exc()