Files
nvfp4-megamoe-kernel/scripts/compute_swizzle.py

88 lines
3.3 KiB
Python

#!/usr/bin/env python3
"""
Print UMMA SMEM layout offsets for FMHA decode.
Directly compute the swizzle pattern without CuTeDSL imports.
"""
import sys
# SW128 shared-memory layout math for BF16 tiles, computed without CuTe.
#
# CuTe (cute/atom/mma_traits_sm90_gmma.hpp) defines the 128-byte-swizzle
# layout atoms in units of BITS:
#   MN-major: Swizzle<3,4,3> o (1024, 8):(1, 1024)
#   K-major:  Swizzle<3,4,3> o (8, 1024):(1024, 1)
# A 1024-bit row is 128 bytes = 64 BF16 elements, so in BF16 element units
# the atoms are (64, 8):(1, 64) and (8, 64):(64, 1) -- 512 elements each.
#
# The swizzle acts on the shared-memory BYTE address: Swizzle<3,4,3> XORs
# byte-address bits [7:9] (128-byte row index mod 8) into bits [4:6]
# (16-byte chunk within the row). The helpers below therefore build a
# logical element offset, convert it to bytes, swizzle, and convert back.
#
# NOTE(review): the tiling assumes CuTe's default column-major atom order for
# tile_to_shape (atoms along the row/MN mode first); confirm it matches how
# the kernel builds its smem layouts.

BF16_BYTES = 2                          # bytes per BF16 element
SW128_ROW_ELEMS = 128 // BF16_BYTES     # 64 BF16 elements per 128-byte row
SW128_ATOM_ELEMS = 8 * SW128_ROW_ELEMS  # 512 BF16 elements per 8-row atom


def swizzle_3_4_3(offset):
    """Apply CuTe ``Swizzle<3,4,3>`` to an offset.

    Swizzle<B, M, S> with B=3, M=4, S=3 takes the three bits starting at
    bit M+S (bits 7-9) and XORs them into the three bits starting at bit M
    (bits 4-6)::

        offset ^ ((offset & 0x380) >> 3)

    Bits 7-9 are left unchanged, so the map is its own inverse and hence a
    bijection. For SW128 shared-memory layouts ``offset`` is a byte offset.
    """
    return offset ^ ((offset & 0x380) >> 3)


def _swizzle_bf16(elem_offset):
    """Swizzle a BF16 element offset by applying Swizzle<3,4,3> to its byte offset."""
    # The swizzle only touches byte bits 4-6, never bit 0, so // is exact.
    return swizzle_3_4_3(elem_offset * BF16_BYTES) // BF16_BYTES


def _check_tile(row, col, rows, cols, row_multiple, col_multiple):
    """Raise ValueError for an untileable tile shape or an out-of-range index."""
    if rows % row_multiple or cols % col_multiple:
        raise ValueError(
            f"tile ({rows}, {cols}) is not a multiple of the "
            f"({row_multiple}, {col_multiple}) SW128 atom"
        )
    if not (0 <= row < rows and 0 <= col < cols):
        raise ValueError(f"index ({row}, {col}) is outside the ({rows}, {cols}) tile")


def mn_sw128_offset(row, col, hd=64, rows=128):
    """Compute the swizzled SMEM offset for an MN-major (rows, hd) BF16 tile.

    ``row`` indexes the MN dimension, which is contiguous in memory, and
    ``col`` indexes the K dimension. The (64, 8):(1, 64) element atom is
    tiled column-major over the tile, giving the logical layout
    ``((64, rows/64), (8, hd/8)) : ((1, 512), (64, 512 * rows/64))``.

    Args:
        row: MN index, ``0 <= row < rows``.
        col: K index, ``0 <= col < hd``.
        hd: K extent (head dimension); must be a multiple of 8.
        rows: MN extent; must be a multiple of 64.

    Returns:
        Offset in BF16 elements from the start of the tile.

    Raises:
        ValueError: if the tile shape is not atom-aligned or the index is
            outside the tile.
    """
    _check_tile(row, col, rows, hd, SW128_ROW_ELEMS, 8)
    row_atoms = rows // SW128_ROW_ELEMS
    logical = (
        row % SW128_ROW_ELEMS
        + (col % 8) * SW128_ROW_ELEMS
        + (row // SW128_ROW_ELEMS) * SW128_ATOM_ELEMS
        + (col // 8) * SW128_ATOM_ELEMS * row_atoms
    )
    return _swizzle_bf16(logical)


def k_sw128_offset(row, col, hd=64, rows=128):
    """Compute the swizzled SMEM offset for a K-major (rows, hd) BF16 tile.

    ``row`` indexes the MN dimension and ``col`` indexes the K dimension,
    which is contiguous in memory. The (8, 64):(64, 1) element atom is tiled
    column-major over the tile, giving the logical layout
    ``((8, rows/8), (64, hd/64)) : ((64, 512), (1, 512 * rows/8))``.
    For ``hd == 64`` this is simply ``row * 64 + col`` before swizzling.

    Args:
        row: MN index, ``0 <= row < rows``.
        col: K index, ``0 <= col < hd``.
        hd: K extent (head dimension); must be a multiple of 64.
        rows: MN extent; must be a multiple of 8.

    Returns:
        Offset in BF16 elements from the start of the tile.

    Raises:
        ValueError: if the tile shape is not atom-aligned or the index is
            outside the tile.
    """
    _check_tile(row, col, rows, hd, 8, SW128_ROW_ELEMS)
    row_atoms = rows // 8
    logical = (
        (row % 8) * SW128_ROW_ELEMS
        + col % SW128_ROW_ELEMS
        + (row // 8) * SW128_ATOM_ELEMS
        + (col // SW128_ROW_ELEMS) * SW128_ATOM_ELEMS * row_atoms
    )
    return _swizzle_bf16(logical)
# Tile checked by this script: the offset functions' default (128, HD=64)
# BF16 tile. The functions are called with (row, col) only, so these must
# match their defaults.
TILE_ROWS = 128
TILE_COLS = 64


def _print_samples(offset_fn, sample_rows, sample_cols):
    """Print offset_fn(row, col) for each sampled coordinate, one row per line."""
    for row in sample_rows:
        print("".join(f" ({row:3d},{col:2d}) -> {offset_fn(row, col):6d}"
                      for col in sample_cols))


def main():
    """Print sample offsets for both layouts and sanity-check them.

    Every offset of the full tile is computed once per layout and reused for
    the max/footprint figures, the bounds check and the uniqueness check.
    (Previously the max/footprint came only from the sampled coordinates and
    could understate the real maximum.)

    Returns:
        True if both layouts map the tile one-to-one into [0, rows*cols).
    """
    matrix_size = TILE_ROWS * TILE_COLS  # 8192 BF16
    coords = [(r, c) for r in range(TILE_ROWS) for c in range(TILE_COLS)]

    # Q (MN-major) swizzle pattern
    print("=== MN_SW128 Q layout (row, col) -> offset (BF16 elements) ===")
    _print_samples(mn_sw128_offset, [0, 1, 2, 3, 127], [0, 1, 7, 8, 15, 16, 63])
    q_offsets = [mn_sw128_offset(r, c) for r, c in coords]
    max_offset = max(q_offsets)
    print(f"\nMax Q offset: {max_offset} BF16 = {max_offset * 2} bytes = {max_offset * 2 / 1024:.1f} KB")

    # K (K-major) swizzle pattern
    print("\n=== K_SW128 K layout (row, col) -> offset (BF16 elements) ===")
    _print_samples(k_sw128_offset, [0, 1, 7, 8, 15, 127], [0, 1, 7, 8, 63])
    k_offsets = [k_sw128_offset(r, c) for r, c in coords]
    max_offset_k = max(k_offsets)
    print(f"\nMax K offset: {max_offset_k} BF16 = {max_offset_k * 2} bytes = {max_offset_k * 2 / 1024:.1f} KB")

    # Does the swizzled layout stay within the matrix size?
    print(f"\nMatrix size: {matrix_size} BF16 = {matrix_size * 2} bytes = {matrix_size * 2 / 1024:.1f} KB")
    print(f"Q layout footprint: {max_offset + 1} BF16 ({(max_offset + 1) * 2 / 1024:.1f} KB)")
    print(f"K layout footprint: {max_offset_k + 1} BF16 ({(max_offset_k + 1) * 2 / 1024:.1f} KB)")

    row_major_max = matrix_size - 1  # 8191
    q_overflow = max_offset > row_major_max
    k_overflow = max_offset_k > row_major_max
    print(f"\nQ offsets exceed row-major bounds: {q_overflow}")
    print(f"K offsets exceed row-major bounds: {k_overflow}")

    # A valid swizzled layout is a permutation: no two coordinates collide.
    q_unique = len(set(q_offsets))
    k_unique = len(set(k_offsets))
    print(f"Q unique offsets: {q_unique} (expected {matrix_size})")
    print(f"K unique offsets: {k_unique} (expected {matrix_size})")

    return not (q_overflow or k_overflow) and q_unique == k_unique == matrix_size


if __name__ == "__main__":
    # Non-zero exit status lets callers detect a broken layout.
    sys.exit(0 if main() else 1)