docs: clarify SF layout remap is in CUDA, not sf_layout.py

sf_layout.py was a no-op (return sf) but the actual remap happens
in remap_sf_to_cutlass_kernel in cutlass_nvfp4_gemm.cu. Updated
sf_layout.py to pure reference docs so nobody gets confused again.
2026-05-14 13:04:31 +00:00
parent 16f91ff0e1
commit 80495c0cd6


@@ -1,156 +1,21 @@
 """
-CUTLASS NVFP4 scale factor layout transformation.
+CUTLASS NVFP4 scale factor layout — reference documentation.
 CUTLASS's Sm1xxBlockScaledConfig expects scale factors in a specific
-interleaved layout (not simple row-major). This module transforms
-our simple (M, K//16) float8_e4m3fn scales into CUTLASS's expected layout.
+interleaved layout (not simple row-major). The layout is defined by:
-The layout is defined by:
 SfAtom = Shape<Shape<_32, _4>, Shape<SFVecSize, _4>>
 with Stride<Stride<_16, _4>, Stride<_0, _1>>
-(for K-major, SFVecSize=16)
+(SFVecSize=16 for NVFP4 UE4M3 block-16)
 layout_SFA = tile_to_shape(SfAtom{}, make_shape(M, K), Step<_2, _1>)
 layout_SFB = tile_to_shape(SfAtom{}, make_shape(N, K), Step<_2, _1>)
-This creates a 2D layout where scale factors are interleaved in blocks
-of 32 rows × (16*4) = 64 columns, with specific stride patterns.
+The actual remap from row-major → CUTLASS interleaved layout happens
+in the CUDA kernel (remap_sf_to_cutlass_kernel in cutlass_nvfp4_gemm.cu),
+NOT in Python. This file exists for reference only.
+The CUDA remap uses cute::idx2crd() to invert the CUTLASS layout:
+for each linear index in the CUTLASS layout, it computes the logical
+(m, k) coordinate and reads from the corresponding row-major position.
 """
-import torch
-import math
-def compute_cutlass_sf_size(M: int, K: int, sf_vec_size: int = 16) -> int:
-    """Compute the number of scale factor elements needed for CUTLASS layout.
-    The CUTLASS layout tile_to_shape(SfAtom{}, make_shape(M, K), Step<_2, _1>)
-    produces a layout with a specific size that may differ from M * (K // sf_vec_size).
-    """
-    # SfAtom shape: (32, 4, sf_vec_size, 4) = (32*4, sf_vec_size*4) = (128, 64) for sf_vec_size=16
-    # tile_to_shape tiles this atom to cover (M, K) with Step<_2, _1>
-    # The resulting size = ceil(M / 128) * 128 * ceil(K / 64) * 64 / sf_vec_size...
-    # Actually it's simpler: the layout size = the total number of elements in the tiled layout
-    # For SfAtom with SFVecSize=16:
-    # SfAtom has shape (32, 4, 16, 4) which is a 2D layout of shape (128, 64)
-    # (after composition). The actual element count per tile = 128 * 4 = 512
-    # (the K dimension is sf_vec_size * 4 = 64, but only 4 "columns" in the SF atom)
-    # Actually, let's just compute it empirically. The size of the layout for (M, K) is:
-    # For SfAtom: Shape<(32,4), (16,4)> = (128, 64) in logical space
-    # Each atom has 128 * 4 = 512 elements (the (32,4) × (4) in the first dim groups)
-    # Wait, that's wrong. Let me think about this differently.
-    # The SfAtom for K-major, SFVecSize=16:
-    # Shape<Shape<32,4>, Shape<16,4>> with Stride<Stride<16,4>, Stride<0,1>>
-    # This is a 2D layout. The shape is (32*4, 16*4) = (128, 64) in terms of
-    # logical coordinates, but the actual number of elements = 32*4*16*4 = 8192? No.
-    #
-    # Actually: Shape<Shape<32,4>, Shape<16,4>> is a nested shape.
-    # Flattened: 32 * 4 * 16 * 4 = 8192 elements per atom? That can't be right.
-    #
-    # No - this is a 2D layout. The shape is (32*4, 16*4) = (128, 64).
-    # Number of elements = 128 * 64 = 8192 per atom? Still too many.
-    #
-    # Actually, the CuTe layout system works differently.
-    # Shape<Shape<32,4>, Shape<16,4>> has 32*4 = 128 elements in the first mode
-    # and 16*4 = 64 elements in the second mode. Total = 128 * 64 = 8192? No.
-    #
-    # A CuTe Layout<Shape, Stride> maps logical indices to memory offsets.
-    # For SfAtom: the total number of elements = product of all shape components
-    # = 32 * 4 * 16 * 4 = 8192.
-    # But that's per "atom". The tile_to_shape function tiles this.
-    #
-    # For the full layout with (M, K):
-    # size = ceil(M/128) * 128 * ceil(K/64) * 64 ... no.
-    # Actually the number of SF elements should be roughly M * (K / SFVecSize)
-    # = M * K / 16. The interleaving just changes the stride pattern, not the count.
-    #
-    # The total number of elements in the layout is the product of the shape.
-    # tile_to_shape preserves the element count (it's just a reshaping/tiling).
-    #
-    # For simple (M, K) with SFVecSize=16:
-    # Total SF elements = M * (K // SFVecSize) = M * (K // 16)
-    # But CUTLASS may pad to multiples of 128 (rows) and 64 (cols)
-    M_pad = math.ceil(M / 128) * 128
-    K_pad = math.ceil(K / 64) * 64 # 64 = SFVecSize * 4
-    return M_pad * (K_pad // sf_vec_size)
-def transform_scales_to_cutlass_layout(
-    sf: torch.Tensor, # (..., M, K//16) float8_e4m3fn, simple row-major
-    M: int,
-    K: int,
-    sf_vec_size: int = 16,
-) -> torch.Tensor:
-    """Transform scale factors from simple row-major to CUTLASS interleaved layout.
-    CUTLASS expects scale factors in the Sm1xxBlockScaledConfig layout:
-    - SfAtom = Shape<Shape<32,4>, Shape<16,4>> with Stride<Stride<16,4>, Stride<0,1>>
-    - tile_to_shape(SfAtom{}, make_shape(M, K), Step<_2, _1>)
-    This produces an interleaved layout where:
-    - Scales are grouped in blocks of 128 rows × 4 "scale columns"
-    - Within each block, 32 groups of 4 interleaved with stride pattern
-    - The K dimension maps to groups of sf_vec_size consecutive scales
-    For our purposes, the simplest approach is to allocate a buffer of the
-    right size and let CUTLASS's own layout computation handle the indexing.
-    Since the total number of elements is the same (just reordered), we can
-    compute the CUTLASS layout on the Python side and remap.
-    For now, we use a simpler approach: pass the scale factors in their
-    natural layout and adjust the CUTLASS kernel to use a compatible stride.
-    """
-    # The CUTLASS block-scaled GEMM reads scales using TMA descriptors
-    # that are initialized with the layout_SFA/layout_SFB computed by
-    # Sm1xxBlockScaledConfig::tile_atom_to_shape_SFA.
-    #
-    # The key insight: the CUTLASS 72b example fills scales using
-    # a cute::Tensor with this layout. If we provide data in a different
-    # layout, the TMA loads will read wrong addresses.
-    #
-    # However, for the _single_ expert GEMM (not grouped), CUTLASS handles
-    # the TMA descriptor setup internally based on the stride/layout we pass.
-    # If we pass our scales with a simple row-major stride, CUTLASS will
-    # still try to read them using the interleaved layout — which is wrong.
-    #
-    # The fix: we need to either:
-    # 1. Remap our data to match the CUTLASS layout (complex)
-    # 2. Or use a different CUTLASS API that accepts custom strides
-    #
-    # For now, let's compute the CUTLASS layout size, allocate, and
-    # do a simple remap. The remap is: for each (row, k_group) in the
-    # CUTLASS layout, write the scale factor from our simple layout.
-    # For SFVecSize=16, the SfAtom has:
-    # Shape<(32,4), (16,4)> with Stride<(16,4), (0,1)>
-    #
-    # This means for the K-major atom:
-    # - Row index (first mode): 32 groups of 4, stride (16, 4)
-    # So element (i, j) in (32, 4) maps to offset i*16 + j*4
-    # - Col index (second mode): 16 groups of 4, stride (0, 1)
-    # So element (k, l) in (16, 4) maps to offset k*0 + l*1 = l
-    #
-    # Combined: offset = (i*16 + j*4) + l * 128
-    # (because the second mode stride (0, 1) means col dimension
-    # doesn't contribute to the first mode and vice versa)
-    #
-    # Actually, for a 2D CuTe layout:
-    # Shape<(32,4), (16,4)> = (128, 64)
-    # Stride<(16,4), (0,1)> means:
-    # First mode (row): stride 16, 4 -> 32*4=128 rows, row stride = (16,4)
-    # Second mode (col): stride 0, 1 -> 16*4=64 cols, col stride = (0,1)
-    #
-    # For element at (r, c) where r = i*4+j (i in 0..31, j in 0..3)
-    # c = k*4+l (k in 0..15, l in 0..3):
-    # offset = i*16 + j*4 + l*128
-    #
-    # Wait, that's not right either. Let me just compute the layout properly.
-    # The total number of SF elements should be M * (K // 16) for our data.
-    # For now, let's just return the scales as-is and see if CUTLASS can handle
-    # a simple stride. The real fix would be to use CUTLASS's own layout helpers.
-    return sf
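
Reference sketch (not part of this commit): the SfAtom shape and stride quoted in the docstring can be written out as an explicit index formula, which may help when checking the CUDA remap by hand. The code below is one reading of that layout, assuming CuTe's convention that the leftmost sub-mode of a nested shape varies fastest (so row m splits as m % 32 and m // 32) and that Step<_2, _1> lays atoms out along K before M. It is not taken from remap_sf_to_cutlass_kernel, whose source is not shown here, and the function names are hypothetical.

```python
# Assumption-labelled sketch of the interleaved scale-factor layout for
# SFVecSize = 16. Each SfAtom covers 128 rows x 4 K-blocks = 512 scales.
ATOM_ROWS = 128      # 32 * 4 rows per atom
ATOM_KBLOCKS = 4     # 4 scale columns (K-blocks of 16 elements) per atom
ATOM_SIZE = ATOM_ROWS * ATOM_KBLOCKS  # 512 stored scales per atom


def cutlass_sf_offset(m: int, kb: int, num_k_blocks: int) -> int:
    """Offset of the scale for row m and K-block kb (kb = k // 16)."""
    k_tiles = -(-num_k_blocks // ATOM_KBLOCKS)  # ceil division
    m_tile, m_in = divmod(m, ATOM_ROWS)
    k_tile, k_in = divmod(kb, ATOM_KBLOCKS)
    # Inside the atom: Stride<Stride<_16, _4>, Stride<_0, _1>>
    in_atom = (m_in % 32) * 16 + (m_in // 32) * 4 + k_in
    # Atoms are assumed contiguous along K first, then along M.
    return (m_tile * k_tiles + k_tile) * ATOM_SIZE + in_atom


def remap_row_major_to_cutlass(sf, M: int, num_k_blocks: int, pad=0):
    """Scatter a flat row-major (M, num_k_blocks) list into the padded layout."""
    m_tiles = -(-M // ATOM_ROWS)
    k_tiles = -(-num_k_blocks // ATOM_KBLOCKS)
    out = [pad] * (m_tiles * k_tiles * ATOM_SIZE)
    for m in range(M):
        for kb in range(num_k_blocks):
            out[cutlass_sf_offset(m, kb, num_k_blocks)] = sf[m * num_k_blocks + kb]
    return out
```

Under these assumptions the padded buffer holds ceil(M/128) * 128 * ceil(K/64) * 4 scales, which agrees with the value the removed compute_cutlass_sf_size returned. Note that the removed comments split the row index as r = i*4 + j; this sketch uses r = i + 32*j instead, which is the assumption most worth verifying against the kernel.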