80 lines
2.6 KiB
Python
80 lines
2.6 KiB
Python
#!/usr/bin/env python3
"""
Print SMEM layouts (and, eventually, UMMA descriptors) for FMHA decode.

This script uses CuTeDSL to construct the SMEM layouts for the Q and K
operand tiles so that the resulting values can be hardcoded in our raw
CUDA kernel.

For each operand it prints the layout, its shape and stride, its size in
elements and bytes, and the offsets of a small (row, col) window so the
swizzle pattern can be inspected by eye.

UMMA descriptor construction is not implemented yet; the script only
prints a reminder of what still needs to be done.

Exit status is 0 on success and 1 if any CuTeDSL call raised.
"""
import sys
import traceback

# Must run before the CuTeDSL imports below so they resolve from this checkout.
sys.path.insert(0, '/root/cutlass/python/CuTeDSL')

import cutlass
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.utils.blackwell_helpers import OperandMajorMode
# NOTE(review): CuTeDSL is usually imported as `cutlass.cute`; confirm that a
# top-level `cute` module is really importable from the path added above.
import cute

# FMHA decode configuration
HEAD_DIM = 64
M = 128  # head-packed rows
SK_TILE = 128  # KV tile size

# Element type handed to the layout builders. The name says bf16 but the
# value is cutlass.float16; the real kernel is meant to use bf16.
# NOTE(review): presumably harmless because both types are 16 bits wide, so
# the layout should not depend on which one is used -- TODO confirm.
bf16 = cutlass.float16

# Bytes per element, used for the "size (bytes)" report lines.
ELEM_BYTES = 2

# Number of rows / columns sampled when dumping swizzle offsets.
OFFSET_ROWS = 4
OFFSET_COLS = 8


def _print_layout_summary(label, layout, leading=""):
    """Print layout, shape, stride and element/byte sizes for one operand.

    Args:
        label: Short operand name used as the line prefix (e.g. "Q").
        layout: Layout object returned by one of the make_smem_layout_* calls.
        leading: Text emitted before the first line (used to insert a blank
            separator line between operands).
    """
    print(f"{leading}{label} SMEM layout: {layout}")
    print(f"{label} SMEM shape: {cute.shape(layout)}")
    print(f"{label} SMEM stride: {cute.stride(layout)}")
    num_elements = cute.size(layout)
    print(f"{label} SMEM size (elements): {num_elements}")
    print(f"{label} SMEM size (bytes): {num_elements * ELEM_BYTES}")


def _print_offsets(label, layout, rows=OFFSET_ROWS, cols=OFFSET_COLS):
    """Print layout(row, col) for a rows x cols window, one text line per row.

    Args:
        label: Short operand name used in the heading (e.g. "Q").
        layout: Callable layout mapping (row, col) to an offset.
        rows: Number of rows to sample, starting at 0.
        cols: Number of columns to sample, starting at 0.
    """
    print(f"\n{label} swizzle offsets (row, col) -> offset:")
    for row in range(rows):
        for col in range(cols):
            print(f" ({row},{col}) -> {layout(row, col)}", end="")
        print()  # end the current row of offsets


# Construct SMEM layouts using the same code the FMHA kernel uses
# MN-major A (Q): (128, 64) BF16
# K-major B (K): (128, 64) BF16
#
# NOTE(review): the keyword names passed to make_smem_layout_a/b are kept
# exactly as written in the original script; confirm them against the
# installed blackwell_helpers signatures.
try:
    # MN-major layout for Q
    q_layout = sm100_utils.make_smem_layout_a(
        major_mode=OperandMajorMode.MN_MAJOR,
        smem_tile_shape=cute.make_shape(M, HEAD_DIM),
        element_type=bf16,
        stage=1,
    )
    _print_layout_summary("Q", q_layout)

    # K-major layout for K. The tile is SK_TILE x HEAD_DIM; the original
    # passed M here, which only worked because M == SK_TILE.
    k_layout = sm100_utils.make_smem_layout_b(
        major_mode=OperandMajorMode.K_MAJOR,
        smem_tile_shape=cute.make_shape(SK_TILE, HEAD_DIM),
        element_type=bf16,
        stage=1,
    )
    _print_layout_summary("K", k_layout, leading="\n")

    # Print a few element offsets to understand the swizzle pattern
    _print_offsets("Q", q_layout)
    _print_offsets("K", k_layout)

    # Construct UMMA descriptors
    # We need to use cute.make_umma_desc which takes a tensor, not raw values
    # But we can construct a tensor with the layout and extract the descriptor
    print("\n=== UMMA Descriptor Construction ===")
    print("(Need to construct CuTe tensor with the SMEM layout and call make_umma_desc)")

except Exception as e:
    # Top-level boundary of a diagnostic script: report everything, then
    # signal failure so callers do not mistake a broken run for success.
    print(f"Error: {e}")
    traceback.print_exc()
    sys.exit(1)