add UMMA descriptor diagnostic script
This commit is contained in:
79
scripts/print_umma_desc.py
Normal file
79
scripts/print_umma_desc.py
Normal file
@@ -0,0 +1,79 @@
|
||||
#!/usr/bin/env python3
"""
Print UMMA SMEM descriptors and layout for FMHA decode.

This script uses CuTeDSL to construct the exact SMEM layout and UMMA
descriptors that the FMHA kernel uses. We then hardcode these values
in our raw CUDA kernel.

Currently only the SMEM layouts and a sample of element offsets are
printed; the UMMA descriptor step is a placeholder message.

Exit status is 0 on success and 1 if any CuTeDSL call raised.
"""
import sys
import traceback

# Must run before the cutlass/cute imports below so that the local
# CuTeDSL checkout is found first on the module search path.
sys.path.insert(0, '/root/cutlass/python/CuTeDSL')

import cutlass
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.utils.blackwell_helpers import OperandMajorMode
# NOTE(review): CuTeDSL usually exposes this as `cutlass.cute`; confirm that
# a top-level `cute` module is importable in the target environment.
import cute

# FMHA decode configuration
HEAD_DIM = 64
M = 128  # head-packed rows
SK_TILE = 128  # KV tile size

# BF16 dtype
# NOTE(review): the name says bf16 but the value is the float16 type. Both are
# 16-bit, so the layout should presumably be identical; confirm, and confirm
# that `cutlass.float16` is the correct attribute spelling.
bf16 = cutlass.float16  # Will use bf16 in the actual kernel

# Size in bytes of one 16-bit element (replaces the former magic `* 2`).
BYTES_PER_ELEMENT = 2


def _print_layout_summary(name, layout):
    """Print layout, shape, stride and size (elements and bytes) for *layout*."""
    print(f"{name} SMEM layout: {layout}")
    print(f"{name} SMEM shape: {cute.shape(layout)}")
    print(f"{name} SMEM stride: {cute.stride(layout)}")
    num_elements = cute.size(layout)
    print(f"{name} SMEM size (elements): {num_elements}")
    print(f"{name} SMEM size (bytes): {num_elements * BYTES_PER_ELEMENT}")


def _print_offsets(name, layout, rows=4, cols=8):
    """Print the offset of each (row, col) in the top-left rows x cols corner."""
    print(f"\n{name} swizzle offsets (row, col) -> offset:")
    for row in range(rows):
        for col in range(cols):
            # NOTE(review): assumes the layout object is callable with two
            # positional coordinates -- confirm against the CuTeDSL API.
            offset = layout(row, col)
            print(f" ({row},{col}) -> {offset}", end="")
        print()


def main():
    """Build the Q and K SMEM layouts and print them.

    Returns:
        0 if every step succeeded, 1 if an exception was raised (the error
        and traceback are printed before returning).
    """
    # Construct SMEM layouts using the same code the FMHA kernel uses
    # MN-major A (Q): (M, HEAD_DIM) = (128, 64) BF16
    # K-major B (K): (SK_TILE, HEAD_DIM) = (128, 64) BF16
    try:
        # MN-major layout for Q
        # NOTE(review): keyword names are kept exactly as originally written;
        # verify them against the installed blackwell_helpers signature.
        q_layout = sm100_utils.make_smem_layout_a(
            major_mode=OperandMajorMode.MN_MAJOR,
            smem_tile_shape=cute.make_shape(M, HEAD_DIM),
            element_type=bf16,
            stage=1,
        )
        _print_layout_summary("Q", q_layout)

        # K-major layout for K. The row extent is the KV tile size, not the
        # head-packed row count (they merely share the value 128 today).
        k_layout = sm100_utils.make_smem_layout_b(
            major_mode=OperandMajorMode.K_MAJOR,
            smem_tile_shape=cute.make_shape(SK_TILE, HEAD_DIM),
            element_type=bf16,
            stage=1,
        )
        print()
        _print_layout_summary("K", k_layout)

        # Print a few element offsets to understand the swizzle pattern
        _print_offsets("Q", q_layout)
        _print_offsets("K", k_layout)

        # Construct UMMA descriptors
        # We need to use cute.make_umma_desc which takes a tensor, not raw values
        # But we can construct a tensor with the layout and extract the descriptor
        # TODO: not implemented yet -- only a placeholder message is printed.
        print("\n=== UMMA Descriptor Construction ===")
        print("(Need to construct CuTe tensor with the SMEM layout and call make_umma_desc)")
    except Exception as e:
        # Top-level boundary of a diagnostic script: report everything, then
        # signal failure through the exit status instead of exiting 0.
        print(f"Error: {e}")
        traceback.print_exc()
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
Reference in New Issue
Block a user