#!/usr/bin/env python3
"""
Print UMMA SMEM layout offsets for FMHA decode.

Directly compute the swizzle pattern without CuTeDSL imports.
"""
import sys

# SWIZZLE_128B atom layout for BF16 MN-major, as modelled by this script:
#   Atom shape: (1024, 8), stride (1, 1024)
#   tile_to_shape for (128, 64):
#     m is always < 1024, so m / 1024 = 0
#     n needs 64 / 8 = 8 atoms
#     offset = m + (n % 8) * 1024 + (n / 8) * 8192
#
# NOTE(review): CUTLASS names the (1024, 8) SW128 atom
# `Layout_MN_SW128_Atom_Bits`; confirm whether these extents are BF16
# elements or bits before trusting the absolute offsets / footprints below.
#
# Swizzle<B=3, M=4, S=3>, using CuTe's definition (ZZZ ^= YYY):
#   yyy_mask = 0b111 << (M + S) = 0x380        (bits 7..9)
#   swizzled = offset ^ ((offset & yyy_mask) >> S)
# i.e. bits 7..9 are XORed into bits 4..6 and every other bit is unchanged.

# Mask selecting the three "YYY" source bits (7..9) of Swizzle<3,4,3>.
_SW343_YYY_MASK = 0x7 << 7
# Right shift that lands the YYY bits on the "ZZZ" target bits (4..6).
_SW343_SHIFT = 3


def swizzle_3_4_3(offset):
    """Apply Swizzle<3,4,3> to an element offset.

    Follows CuTe's ``Swizzle<B, M, S>::apply``, which computes
    ``offset ^ ((offset & yyy_msk) >> S)`` with ``yyy_msk = 0b111 << 7``
    and ``S = 3``. Bits 7..9 are XORed into bits 4..6.

    The YYY bits themselves are never modified, so the mapping is an
    involution (applying it twice returns the input). It is therefore a
    bijection and cannot make two distinct offsets collide.

    Args:
        offset: Non-negative integer offset.

    Returns:
        The swizzled offset.
    """
    # The previous formula XORed (bits 3..5 ^ bits 7..9) into bits 3..5.
    # That overwrote bits 3..5 with bits 7..9, so it was not a permutation
    # (0 and 8 both mapped to 0) and did not match Swizzle<3,4,3>.
    return offset ^ ((offset & _SW343_YYY_MASK) >> _SW343_SHIFT)


def mn_sw128_offset(row, col, hd=64):
    """Compute swizzled SMEM offset for MN-major (128, HD) BF16 matrix.

    Args:
        row: M index. The formula assumes ``0 <= row < 1024`` (a single atom
            along M).
        col: N index; the script only queries ``0 <= col < 64``.
        hd: Head dimension. Unused by the formula, because extra columns just
            add further 8-column atoms. Kept for interface compatibility.

    Returns:
        Swizzled offset in BF16 elements.
    """
    # Logical (pre-swizzle) offset in BF16 elements.
    logical = row + (col % 8) * 1024 + (col // 8) * 8192
    return swizzle_3_4_3(logical)


def k_sw128_offset(row, col, hd=64):
    """Compute swizzled SMEM offset for K-major (128, HD) BF16 matrix.

    K_SW128 atom: (8, 1024) BF16, stride (1, 8)
    For (128, 64): k ranges 0..127, mn ranges 0..63
      logical = (k % 8) + (mn % 1024) * 8 + (k / 8) * 8 * 1024
              = (k % 8) + mn * 8 + (k / 8) * 8192

    Args:
        row: K index (``k`` above).
        col: MN index (``mn`` above). The simplified formula assumes
            ``col < 1024``.
        hd: Head dimension. Unused by the formula; kept for interface
            compatibility.

    Returns:
        Swizzled offset in BF16 elements.
    """
    logical = (row % 8) + col * 8 + (row // 8) * 8192
    return swizzle_3_4_3(logical)


def _all_offsets(offset_fn, rows, cols):
    """Return ``offset_fn(r, c)`` for every element of a ``rows x cols`` matrix."""
    return [offset_fn(r, c) for r in range(rows) for c in range(cols)]


def _print_samples(title, offset_fn, rows, cols):
    """Print ``offset_fn(row, col)`` for a sample grid, one output line per row."""
    print(title)
    for row in rows:
        for col in cols:
            offset = offset_fn(row, col)
            print(f"  ({row:3d},{col:2d}) -> {offset:6d}", end="")
        print()


def main():
    """Print sample offsets, footprints and sanity checks for the Q and K layouts.

    Returns:
        0 if both layouts map every element of the 128 x 64 matrix to a
        distinct offset, otherwise 1. Used as the process exit status.
    """
    rows, cols = 128, 64
    matrix_size = rows * cols  # 8192 BF16
    row_major_max = matrix_size - 1  # 8191

    # Evaluate every element once. Max, footprint, overflow and uniqueness
    # are all derived from the full matrix rather than from the printed
    # samples, which can miss the true maximum.
    q_all = _all_offsets(mn_sw128_offset, rows, cols)
    k_all = _all_offsets(k_sw128_offset, rows, cols)
    max_offset = max(q_all)
    max_offset_k = max(k_all)

    # Print Q (MN-major) swizzle pattern
    _print_samples(
        "=== MN_SW128 Q layout (row, col) -> offset (BF16 elements) ===",
        mn_sw128_offset,
        [0, 1, 2, 3, 127],
        [0, 1, 7, 8, 15, 16, 63],
    )
    print(
        f"\nMax Q offset: {max_offset} BF16 = {max_offset * 2} bytes"
        f" = {max_offset * 2 / 1024:.1f} KB"
    )

    # Print K (K-major) swizzle pattern
    _print_samples(
        "\n=== K_SW128 K layout (row, col) -> offset (BF16 elements) ===",
        k_sw128_offset,
        [0, 1, 7, 8, 15, 127],
        [0, 1, 7, 8, 63],
    )
    print(
        f"\nMax K offset: {max_offset_k} BF16 = {max_offset_k * 2} bytes"
        f" = {max_offset_k * 2 / 1024:.1f} KB"
    )

    # Verify: does the swizzle stay within the matrix size?
    print(
        f"\nMatrix size: {matrix_size} BF16 = {matrix_size * 2} bytes"
        f" = {matrix_size * 2 / 1024:.1f} KB"
    )
    print(
        f"Q layout footprint: {max_offset + 1} BF16"
        f" ({(max_offset + 1) * 2 / 1024:.1f} KB)"
    )
    print(
        f"K layout footprint: {max_offset_k + 1} BF16"
        f" ({(max_offset_k + 1) * 2 / 1024:.1f} KB)"
    )

    # Check if any offsets exceed the simple row-major layout
    q_overflow = max_offset > row_major_max
    k_overflow = max_offset_k > row_major_max
    print(f"\nQ offsets exceed row-major bounds: {q_overflow}")
    print(f"K offsets exceed row-major bounds: {k_overflow}")

    # Count unique offsets (verify no collisions)
    q_unique = len(set(q_all))
    k_unique = len(set(k_all))
    print(f"Q unique offsets: {q_unique} (expected {matrix_size})")
    print(f"K unique offsets: {k_unique} (expected {matrix_size})")

    return 0 if q_unique == k_unique == matrix_size else 1


if __name__ == "__main__":
    sys.exit(main())