diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/sf_layout.py b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/sf_layout.py index 8738fbf3..ec74b4cf 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/sf_layout.py +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/sf_layout.py @@ -1,156 +1,21 @@ """ -CUTLASS NVFP4 scale factor layout transformation. +CUTLASS NVFP4 scale factor layout — reference documentation. CUTLASS's Sm1xxBlockScaledConfig expects scale factors in a specific -interleaved layout (not simple row-major). This module transforms -our simple (M, K//16) float8_e4m3fn scales into CUTLASS's expected layout. +interleaved layout (not simple row-major). The layout is defined by: -The layout is defined by: SfAtom = Shape<Shape<_32,_4>, Shape<SFVecSize,_4>> with Stride<Stride<_16,_4>, Stride<_0, _1>> - (for K-major, SFVecSize=16) + (SFVecSize=16 for NVFP4 UE4M3 block-16) layout_SFA = tile_to_shape(SfAtom{}, make_shape(M, K), Step<_2, _1>) layout_SFB = tile_to_shape(SfAtom{}, make_shape(N, K), Step<_2, _1>) -This creates a 2D layout where scale factors are interleaved in blocks -of 32 rows × (16*4) = 64 columns, with specific stride patterns. +The actual remap from row-major → CUTLASS interleaved layout happens +in the CUDA kernel (remap_sf_to_cutlass_kernel in cutlass_nvfp4_gemm.cu), +NOT in Python. This file exists for reference only. + +The CUDA remap uses cute::idx2crd() to invert the CUTLASS layout: +for each linear index in the CUTLASS layout, it computes the logical +(m, k) coordinate and reads from the corresponding row-major position. """ - -import torch -import math - - -def compute_cutlass_sf_size(M: int, K: int, sf_vec_size: int = 16) -> int: - """Compute the number of scale factor elements needed for CUTLASS layout. - - The CUTLASS layout tile_to_shape(SfAtom{}, make_shape(M, K), Step<_2, _1>) - produces a layout with a specific size that may differ from M * (K // sf_vec_size). 
- """ - # SfAtom shape: (32, 4, sf_vec_size, 4) = (32*4, sf_vec_size*4) = (128, 64) for sf_vec_size=16 - # tile_to_shape tiles this atom to cover (M, K) with Step<_2, _1> - # The resulting size = ceil(M / 128) * 128 * ceil(K / 64) * 64 / sf_vec_size... - # Actually it's simpler: the layout size = the total number of elements in the tiled layout - - # For SfAtom with SFVecSize=16: - # SfAtom has shape (32, 4, 16, 4) which is a 2D layout of shape (128, 64) - # (after composition). The actual element count per tile = 128 * 4 = 512 - # (the K dimension is sf_vec_size * 4 = 64, but only 4 "columns" in the SF atom) - - # Actually, let's just compute it empirically. The size of the layout for (M, K) is: - # For SfAtom: Shape<(32,4), (16,4)> = (128, 64) in logical space - # Each atom has 128 * 4 = 512 elements (the (32,4) × (4) in the first dim groups) - # Wait, that's wrong. Let me think about this differently. - - # The SfAtom for K-major, SFVecSize=16: - # Shape<Shape<32,4>, Shape<16,4>> with Stride<Stride<16,4>, Stride<0,1>> - # This is a 2D layout. The shape is (32*4, 16*4) = (128, 64) in terms of - # logical coordinates, but the actual number of elements = 32*4*16*4 = 8192? No. - # - # Actually: Shape<Shape<32,4>, Shape<16,4>> is a nested shape. - # Flattened: 32 * 4 * 16 * 4 = 8192 elements per atom? That can't be right. - # - # No - this is a 2D layout. The shape is (32*4, 16*4) = (128, 64). - # Number of elements = 128 * 64 = 8192 per atom? Still too many. - # - # Actually, the CuTe layout system works differently. - # Shape<Shape<32,4>, Shape<16,4>> has 32*4 = 128 elements in the first mode - # and 16*4 = 64 elements in the second mode. Total = 128 * 64 = 8192? No. - # - # A CuTe Layout maps logical indices to memory offsets. - # For SfAtom: the total number of elements = product of all shape components - # = 32 * 4 * 16 * 4 = 8192. - # But that's per "atom". The tile_to_shape function tiles this. - # - # For the full layout with (M, K): - # size = ceil(M/128) * 128 * ceil(K/64) * 64 ... no. 
- # Actually the number of SF elements should be roughly M * (K / SFVecSize) - # = M * K / 16. The interleaving just changes the stride pattern, not the count. - # - # The total number of elements in the layout is the product of the shape. - # tile_to_shape preserves the element count (it's just a reshaping/tiling). - # - # For simple (M, K) with SFVecSize=16: - # Total SF elements = M * (K // SFVecSize) = M * (K // 16) - # But CUTLASS may pad to multiples of 128 (rows) and 64 (cols) - M_pad = math.ceil(M / 128) * 128 - K_pad = math.ceil(K / 64) * 64 # 64 = SFVecSize * 4 - return M_pad * (K_pad // sf_vec_size) - - -def transform_scales_to_cutlass_layout( - sf: torch.Tensor, # (..., M, K//16) float8_e4m3fn, simple row-major - M: int, - K: int, - sf_vec_size: int = 16, -) -> torch.Tensor: - """Transform scale factors from simple row-major to CUTLASS interleaved layout. - - CUTLASS expects scale factors in the Sm1xxBlockScaledConfig layout: - - SfAtom = Shape<Shape<32,4>, Shape<16,4>> with Stride<Stride<16,4>, Stride<0,1>> - - tile_to_shape(SfAtom{}, make_shape(M, K), Step<_2, _1>) - - This produces an interleaved layout where: - - Scales are grouped in blocks of 128 rows × 4 "scale columns" - - Within each block, 32 groups of 4 interleaved with stride pattern - - The K dimension maps to groups of sf_vec_size consecutive scales - - For our purposes, the simplest approach is to allocate a buffer of the - right size and let CUTLASS's own layout computation handle the indexing. - - Since the total number of elements is the same (just reordered), we can - compute the CUTLASS layout on the Python side and remap. - - For now, we use a simpler approach: pass the scale factors in their - natural layout and adjust the CUTLASS kernel to use a compatible stride. - """ - # The CUTLASS block-scaled GEMM reads scales using TMA descriptors - # that are initialized with the layout_SFA/layout_SFB computed by - # Sm1xxBlockScaledConfig::tile_atom_to_shape_SFA. 
- # - # The key insight: the CUTLASS 72b example fills scales using - # a cute::Tensor with this layout. If we provide data in a different - # layout, the TMA loads will read wrong addresses. - # - # However, for the _single_ expert GEMM (not grouped), CUTLASS handles - # the TMA descriptor setup internally based on the stride/layout we pass. - # If we pass our scales with a simple row-major stride, CUTLASS will - # still try to read them using the interleaved layout — which is wrong. - # - # The fix: we need to either: - # 1. Remap our data to match the CUTLASS layout (complex) - # 2. Or use a different CUTLASS API that accepts custom strides - # - # For now, let's compute the CUTLASS layout size, allocate, and - # do a simple remap. The remap is: for each (row, k_group) in the - # CUTLASS layout, write the scale factor from our simple layout. - - # For SFVecSize=16, the SfAtom has: - # Shape<(32,4), (16,4)> with Stride<(16,4), (0,1)> - # - # This means for the K-major atom: - # - Row index (first mode): 32 groups of 4, stride (16, 4) - # So element (i, j) in (32, 4) maps to offset i*16 + j*4 - # - Col index (second mode): 16 groups of 4, stride (0, 1) - # So element (k, l) in (16, 4) maps to offset k*0 + l*1 = l - # - # Combined: offset = (i*16 + j*4) + l * 128 - # (because the second mode stride (0, 1) means col dimension - # doesn't contribute to the first mode and vice versa) - # - # Actually, for a 2D CuTe layout: - # Shape<(32,4), (16,4)> = (128, 64) - # Stride<(16,4), (0,1)> means: - # First mode (row): stride 16, 4 -> 32*4=128 rows, row stride = (16,4) - # Second mode (col): stride 0, 1 -> 16*4=64 cols, col stride = (0,1) - # - # For element at (r, c) where r = i*4+j (i in 0..31, j in 0..3) - # c = k*4+l (k in 0..15, l in 0..3): - # offset = i*16 + j*4 + l*128 - # - # Wait, that's not right either. Let me just compute the layout properly. - # The total number of SF elements should be M * (K // 16) for our data. 
- - # For now, let's just return the scales as-is and see if CUTLASS can handle - # a simple stride. The real fix would be to use CUTLASS's own layout helpers. - return sf