Files
nvfp4-megamoe-kernel/scripts/dump_umma_desc.py
2026-05-28 08:59:19 +00:00

34 lines
938 B
Python

"""
Dump FMHA SMEM layout info and try to extract UMMA descriptor values.
"""
import torch
import sys
sys.path.insert(0, '.')
from dsv4.kernels.attention.fmha import FmhaKernel
def describe_layout(label, layout):
    """Return the lines describing how a SMEM layout object is structured.

    The concrete type of the kernel's ``*_smem_s`` attributes is not known
    to this script, so the object is probed rather than assumed:

    * if it exposes ``.inner`` it is reported attribute by attribute
      (``.outer`` is looked up separately so a missing attribute is
      reported instead of raising ``AttributeError``);
    * otherwise, if it is a tuple, every element is listed with its type;
    * otherwise a single line states that neither shape matched, so the
      dump never goes silently empty.

    Args:
        label: Short name used as the prefix of every line (e.g. ``"Q"``).
        layout: The object to inspect.

    Returns:
        A list of strings, one per line to print.
    """
    if hasattr(layout, 'inner'):
        lines = [f"{label} inner (swizzle): {layout.inner}"]
        if hasattr(layout, 'outer'):
            lines.append(f"{label} outer (layout): {layout.outer}")
        else:
            # Only .inner was checked originally; report rather than crash.
            lines.append(f"{label} outer (layout): <no 'outer' attribute>")
        return lines

    if isinstance(layout, tuple):
        lines = [f"{label} is tuple of length {len(layout)}"]
        lines.extend(f" [{i}]: {type(x)} = {x}" for i, x in enumerate(layout))
        return lines

    return [f"{label} has neither .inner nor tuple structure: {type(layout)}"]


def main():
    """Build an FmhaKernel and print its Q/K SMEM layouts and SMEM sizes."""
    kernel = FmhaKernel(head_dim=64, use_smem_p=False, normalize=True)

    # NOTE(review): the *_smem_s objects are presumably
    # (inner_swizzle, outer_layout) pairs -- confirm against FmhaKernel.
    q_s = kernel.q_smem_s
    k_s = kernel.k_smem_s
    print(f"q_smem_s type: {type(q_s)}")
    print(f"q_smem_s: {q_s}")
    print(f"k_smem_s: {k_s}")

    # Try to access inner/outer (or fall back to tuple elements) for Q.
    for line in describe_layout("Q", q_s):
        print(line)

    # SMEM sizes as reported by the kernel object.
    print(f"\nQ SMEM size: {kernel.q_smem_size}")
    print(f"K SMEM size: {kernel.k_smem_size}")
    print(f"V SMEM size: {kernel.v_smem_size}")
    print(f"Total SMEM: {kernel.smem_size}")


if __name__ == "__main__":
    main()