"""
Dump FMHA SMEM layout info for the Q/K/V shared-memory buffers.

Builds an ``FmhaKernel`` (head_dim=64, use_smem_p=False, normalize=True) and
prints the Q/K SMEM layout objects, a per-component breakdown of each, and the
SMEM byte sizes reported by the kernel.

NOTE(review): the original intent was also to extract UMMA descriptor values;
this script does not do that yet — it only prints the raw layout objects so the
descriptor fields can be derived from them by hand.
"""

import sys

import torch  # noqa: F401  # kept from original; presumably needed for side effects (e.g. CUDA init) — confirm

# Make the repository root (assumed to be the current working directory)
# importable so the project-local ``dsv4`` package resolves.
sys.path.insert(0, '.')

from dsv4.kernels.attention.fmha import FmhaKernel  # noqa: E402


def _describe_smem_layout(label: str, layout: object) -> None:
    """Print the components of one SMEM layout object.

    Handles objects exposing both ``inner``/``outer`` attributes and plain
    tuples; any other type is reported explicitly instead of being skipped.

    Args:
        label: Short name used as the prefix of each printed line (e.g. "Q").
        layout: The ``*_smem_s`` value read from the kernel.
    """
    if hasattr(layout, 'inner') and hasattr(layout, 'outer'):
        print(f"{label} inner (swizzle): {layout.inner}")
        print(f"{label} outer (layout): {layout.outer}")
    elif isinstance(layout, tuple):
        print(f"{label} is tuple of length {len(layout)}")
        for i, x in enumerate(layout):
            print(f" [{i}]: {type(x)} = {x}")
    else:
        print(f"{label}: unrecognised layout type {type(layout).__name__}; "
              "no inner/outer decomposition available")


def main() -> None:
    """Construct the kernel and print its SMEM layouts and sizes."""
    kernel = FmhaKernel(head_dim=64, use_smem_p=False, normalize=True)

    # NOTE(review): expected to be (inner_swizzle, outer_layout) pairs, either
    # as a tuple or as an object with .inner/.outer — the helper handles both.
    q_s = kernel.q_smem_s
    k_s = kernel.k_smem_s

    print(f"q_smem_s type: {type(q_s)}")
    print(f"q_smem_s: {q_s}")
    print(f"k_smem_s type: {type(k_s)}")
    print(f"k_smem_s: {k_s}")

    # Break each layout into its components (Q and K treated the same way).
    _describe_smem_layout("Q", q_s)
    _describe_smem_layout("K", k_s)

    # SMEM sizes as reported by the kernel.
    print(f"\nQ SMEM size: {kernel.q_smem_size}")
    print(f"K SMEM size: {kernel.k_smem_size}")
    print(f"V SMEM size: {kernel.v_smem_size}")
    print(f"Total SMEM: {kernel.smem_size}")


if __name__ == "__main__":
    main()