88 lines
3.3 KiB
Python
88 lines
3.3 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Print UMMA SMEM layout offsets for FMHA decode.
|
|
Directly compute the swizzle pattern without CuTeDSL imports.
|
|
"""
|
|
import sys
|
|
|
|
# SWIZZLE_128B atom layout for BF16 MN-major:
|
|
# Atom shape: (1024, 8) BF16, stride (1, 1024)
|
|
# tile_to_shape for (128, 64):
|
|
# m is always < 1024, so m/1024 = 0
|
|
# n needs 64/8 = 8 atoms
|
|
# offset = m + (n % 8) * 1024 + (n // 8) * 8192   (// is integer division)
|
|
#
|
|
# Swizzle<3,4,3>  (CuTe Swizzle<BBits=3, MBase=4, SShift=3>):
#   XOR the 3 bits at positions [7, 10) into the 3 bits at positions [4, 7);
#   every other bit is left untouched.
#   swizzled = offset ^ (((offset >> 7) & 0x7) << 4)


def swizzle_3_4_3(offset):
    """Apply the Swizzle<3,4,3> bit permutation to an integer offset.

    Follows the CuTe ``Swizzle<BBits, MBase, SShift>`` definition with
    BBits=3, MBase=4, SShift=3: the source bits [7, 10) are XOR-ed into the
    destination bits [4, 7).  Because the source bits are not modified, the
    mapping is a bijection and is its own inverse.

    The previous expression XOR-ed bits 3..5 with themselves, which replaced
    them by bits 7..9 and made distinct offsets collide.

    Args:
        offset: Non-negative integer offset to swizzle.

    Returns:
        The swizzled offset (same unit as ``offset``).

    NOTE(review): the functor acts on whatever integer it is given; whether
    callers should pass element offsets or byte offsets for a SW128 layout
    is not established in this file -- confirm against the CUTLASS layout
    atom definitions.
    """
    # Isolate the 3 source bits (7..9) and move them onto bits 4..6.
    xor_mask = ((offset >> 7) & 0x7) << 4
    return offset ^ xor_mask
|
|
|
|
def mn_sw128_offset(row, col, hd=64):
    """Return the swizzled SMEM offset of (row, col) in the MN-major layout.

    The un-swizzled offset is built from three terms: ``row`` contributes
    with stride 1, ``col % 8`` with stride 1024, and ``col // 8`` with
    stride 8192.  The sum is then passed through ``swizzle_3_4_3``.

    Args:
        row: Index along the stride-1 dimension.
        col: Index along the other dimension.
        hd: Accepted for interface compatibility; not used by the
            computation.

    Returns:
        The swizzled offset as an integer.
    """
    # Split the column into (which group of 8 columns, column inside the group).
    col_group, col_in_group = divmod(col, 8)
    linear = col_group * 8192 + col_in_group * 1024 + row
    return swizzle_3_4_3(linear)
|
|
|
|
def k_sw128_offset(row, col, hd=64):
    """Return the swizzled SMEM offset of (row, col) in the K-major layout.

    Layout notes carried over from the original author:
        K_SW128 atom: (8, 1024) BF16, stride (1, 8)
        For (128, 64): k ranges 0..127, mn ranges 0..63
        logical = (k % 8) + (mn % 1024) * 8 + (k // 8) * 8 * 1024
                = (k % 8) + mn * 8 + (k // 8) * 8192

    Here ``row`` plays the role of ``k`` and ``col`` the role of ``mn``.
    The logical offset is then passed through ``swizzle_3_4_3``.

    Args:
        row: The ``k`` index.
        col: The ``mn`` index.
        hd: Accepted for interface compatibility; not used by the
            computation.

    Returns:
        The swizzled offset as an integer.
    """
    # Split k into (which group of 8, position inside the group).
    k_group, k_in_group = divmod(row, 8)
    linear = k_group * 8192 + col * 8 + k_in_group
    return swizzle_3_4_3(linear)
|
|
|
|
# Print a sample of the Q (MN-major) swizzle pattern.  Only a handful of
# coordinates are shown to keep the output readable.
print("=== MN_SW128 Q layout (row, col) -> offset (BF16 elements) ===")
for row in [0, 1, 2, 3, 127]:
    for col in [0, 1, 7, 8, 15, 16, 63]:
        offset = mn_sw128_offset(row, col)
        print(f" ({row:3d},{col:2d}) -> {offset:6d}", end="")
    print()

# The maximum is taken over every coordinate of the 128 x 64 matrix, not just
# the printed samples: a swizzle can place the largest offset at a coordinate
# that is not in the sample list, which would under-report the footprint.
max_offset = max(mn_sw128_offset(r, c) for r in range(128) for c in range(64))

print(f"\nMax Q offset: {max_offset} BF16 = {max_offset * 2} bytes = {max_offset * 2 / 1024:.1f} KB")

# Print a sample of the K (K-major) swizzle pattern.
print("\n=== K_SW128 K layout (row, col) -> offset (BF16 elements) ===")
for row in [0, 1, 7, 8, 15, 127]:
    for col in [0, 1, 7, 8, 63]:
        offset = k_sw128_offset(row, col)
        print(f" ({row:3d},{col:2d}) -> {offset:6d}", end="")
    print()

# Same reasoning as for Q: scan the full coordinate range for the maximum.
max_offset_k = max(k_sw128_offset(r, c) for r in range(128) for c in range(64))

print(f"\nMax K offset: {max_offset_k} BF16 = {max_offset_k * 2} bytes = {max_offset_k * 2 / 1024:.1f} KB")

# Verify: does the swizzle stay within the matrix size?
# Footprint = highest offset touched + 1; the factor 2 converts BF16 element
# counts to bytes.
matrix_size = 128 * 64  # 8192 BF16
print(f"\nMatrix size: {matrix_size} BF16 = {matrix_size * 2} bytes = {matrix_size * 2 / 1024:.1f} KB")
print(f"Q layout footprint: {max_offset + 1} BF16 ({(max_offset + 1) * 2 / 1024:.1f} KB)")
print(f"K layout footprint: {max_offset_k + 1} BF16 ({(max_offset_k + 1) * 2 / 1024:.1f} KB)")
|
|
|
|
# Gather the swizzled offset of every (row, col) coordinate once; the sets
# serve both the bounds check and the collision check below.
row_major_max = 128 * 64 - 1  # 8191
q_offsets = set()
k_offsets = set()
for r in range(128):
    for c in range(64):
        q_offsets.add(mn_sw128_offset(r, c))
        k_offsets.add(k_sw128_offset(r, c))

# An offset exceeds the dense row-major range exactly when the largest one does.
q_overflow = max(q_offsets) > row_major_max
k_overflow = max(k_offsets) > row_major_max
print(f"\nQ offsets exceed row-major bounds: {q_overflow}")
print(f"K offsets exceed row-major bounds: {k_overflow}")

# Fewer unique offsets than coordinates means two coordinates collided.
print(f"Q unique offsets: {len(q_offsets)} (expected {128*64})")
print(f"K unique offsets: {len(k_offsets)} (expected {128*64})")
|