docs: clarify SF layout remap is in CUDA, not sf_layout.py
sf_layout.py was a no-op (return sf) but the actual remap happens in remap_sf_to_cutlass_kernel in cutlass_nvfp4_gemm.cu. Updated sf_layout.py to pure reference docs so nobody gets confused again.
This commit is contained in:
@@ -1,156 +1,21 @@
|
||||
"""
|
||||
CUTLASS NVFP4 scale factor layout transformation.
|
||||
CUTLASS NVFP4 scale factor layout — reference documentation.
|
||||
|
||||
CUTLASS's Sm1xxBlockScaledConfig expects scale factors in a specific
|
||||
interleaved layout (not simple row-major). This module transforms
|
||||
our simple (M, K//16) float8_e4m3fn scales into CUTLASS's expected layout.
|
||||
interleaved layout (not simple row-major). The layout is defined by:
|
||||
|
||||
The layout is defined by:
|
||||
SfAtom = Shape<Shape<_32, _4>, Shape<SFVecSize, _4>>
|
||||
with Stride<Stride<_16, _4>, Stride<_0, _1>>
|
||||
(for K-major, SFVecSize=16)
|
||||
(SFVecSize=16 for NVFP4 UE4M3 block-16)
|
||||
|
||||
layout_SFA = tile_to_shape(SfAtom{}, make_shape(M, K), Step<_2, _1>)
|
||||
layout_SFB = tile_to_shape(SfAtom{}, make_shape(N, K), Step<_2, _1>)
|
||||
|
||||
This creates a 2D layout where scale factors are interleaved in blocks
|
||||
of 32 rows × (16*4) = 64 columns, with specific stride patterns.
|
||||
The actual remap from row-major → CUTLASS interleaved layout happens
|
||||
in the CUDA kernel (remap_sf_to_cutlass_kernel in cutlass_nvfp4_gemm.cu),
|
||||
NOT in Python. This file exists for reference only.
|
||||
|
||||
The CUDA remap uses cute::idx2crd() to invert the CUTLASS layout:
|
||||
for each linear index in the CUTLASS layout, it computes the logical
|
||||
(m, k) coordinate and reads from the corresponding row-major position.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import math
|
||||
|
||||
|
||||
def compute_cutlass_sf_size(M: int, K: int, sf_vec_size: int = 16) -> int:
|
||||
"""Compute the number of scale factor elements needed for CUTLASS layout.
|
||||
|
||||
The CUTLASS layout tile_to_shape(SfAtom{}, make_shape(M, K), Step<_2, _1>)
|
||||
produces a layout with a specific size that may differ from M * (K // sf_vec_size).
|
||||
"""
|
||||
# SfAtom shape: (32, 4, sf_vec_size, 4) = (32*4, sf_vec_size*4) = (128, 64) for sf_vec_size=16
|
||||
# tile_to_shape tiles this atom to cover (M, K) with Step<_2, _1>
|
||||
# The resulting size = ceil(M / 128) * 128 * ceil(K / 64) * 64 / sf_vec_size...
|
||||
# Actually it's simpler: the layout size = the total number of elements in the tiled layout
|
||||
|
||||
# For SfAtom with SFVecSize=16:
|
||||
# SfAtom has shape (32, 4, 16, 4) which is a 2D layout of shape (128, 64)
|
||||
# (after composition). The actual element count per tile = 128 * 4 = 512
|
||||
# (the K dimension is sf_vec_size * 4 = 64, but only 4 "columns" in the SF atom)
|
||||
|
||||
# Actually, let's just compute it empirically. The size of the layout for (M, K) is:
|
||||
# For SfAtom: Shape<(32,4), (16,4)> = (128, 64) in logical space
|
||||
# Each atom has 128 * 4 = 512 elements (the (32,4) × (4) in the first dim groups)
|
||||
# Wait, that's wrong. Let me think about this differently.
|
||||
|
||||
# The SfAtom for K-major, SFVecSize=16:
|
||||
# Shape<Shape<32,4>, Shape<16,4>> with Stride<Stride<16,4>, Stride<0,1>>
|
||||
# This is a 2D layout. The shape is (32*4, 16*4) = (128, 64) in terms of
|
||||
# logical coordinates, but the actual number of elements = 32*4*16*4 = 8192? No.
|
||||
#
|
||||
# Actually: Shape<Shape<32,4>, Shape<16,4>> is a nested shape.
|
||||
# Flattened: 32 * 4 * 16 * 4 = 8192 elements per atom? That can't be right.
|
||||
#
|
||||
# No - this is a 2D layout. The shape is (32*4, 16*4) = (128, 64).
|
||||
# Number of elements = 128 * 64 = 8192 per atom? Still too many.
|
||||
#
|
||||
# Actually, the CuTe layout system works differently.
|
||||
# Shape<Shape<32,4>, Shape<16,4>> has 32*4 = 128 elements in the first mode
|
||||
# and 16*4 = 64 elements in the second mode. Total = 128 * 64 = 8192? No.
|
||||
#
|
||||
# A CuTe Layout<Shape, Stride> maps logical indices to memory offsets.
|
||||
# For SfAtom: the total number of elements = product of all shape components
|
||||
# = 32 * 4 * 16 * 4 = 8192.
|
||||
# But that's per "atom". The tile_to_shape function tiles this.
|
||||
#
|
||||
# For the full layout with (M, K):
|
||||
# size = ceil(M/128) * 128 * ceil(K/64) * 64 ... no.
|
||||
# Actually the number of SF elements should be roughly M * (K / SFVecSize)
|
||||
# = M * K / 16. The interleaving just changes the stride pattern, not the count.
|
||||
#
|
||||
# The total number of elements in the layout is the product of the shape.
|
||||
# tile_to_shape preserves the element count (it's just a reshaping/tiling).
|
||||
#
|
||||
# For simple (M, K) with SFVecSize=16:
|
||||
# Total SF elements = M * (K // SFVecSize) = M * (K // 16)
|
||||
# But CUTLASS may pad to multiples of 128 (rows) and 64 (cols)
|
||||
M_pad = math.ceil(M / 128) * 128
|
||||
K_pad = math.ceil(K / 64) * 64 # 64 = SFVecSize * 4
|
||||
return M_pad * (K_pad // sf_vec_size)
|
||||
|
||||
|
||||
def transform_scales_to_cutlass_layout(
|
||||
sf: torch.Tensor, # (..., M, K//16) float8_e4m3fn, simple row-major
|
||||
M: int,
|
||||
K: int,
|
||||
sf_vec_size: int = 16,
|
||||
) -> torch.Tensor:
|
||||
"""Transform scale factors from simple row-major to CUTLASS interleaved layout.
|
||||
|
||||
CUTLASS expects scale factors in the Sm1xxBlockScaledConfig layout:
|
||||
- SfAtom = Shape<Shape<32,4>, Shape<16,4>> with Stride<Stride<16,4>, Stride<0,1>>
|
||||
- tile_to_shape(SfAtom{}, make_shape(M, K), Step<_2, _1>)
|
||||
|
||||
This produces an interleaved layout where:
|
||||
- Scales are grouped in blocks of 128 rows × 4 "scale columns"
|
||||
- Within each block, 32 groups of 4 interleaved with stride pattern
|
||||
- The K dimension maps to groups of sf_vec_size consecutive scales
|
||||
|
||||
For our purposes, the simplest approach is to allocate a buffer of the
|
||||
right size and let CUTLASS's own layout computation handle the indexing.
|
||||
|
||||
Since the total number of elements is the same (just reordered), we can
|
||||
compute the CUTLASS layout on the Python side and remap.
|
||||
|
||||
For now, we use a simpler approach: pass the scale factors in their
|
||||
natural layout and adjust the CUTLASS kernel to use a compatible stride.
|
||||
"""
|
||||
# The CUTLASS block-scaled GEMM reads scales using TMA descriptors
|
||||
# that are initialized with the layout_SFA/layout_SFB computed by
|
||||
# Sm1xxBlockScaledConfig::tile_atom_to_shape_SFA.
|
||||
#
|
||||
# The key insight: the CUTLASS 72b example fills scales using
|
||||
# a cute::Tensor with this layout. If we provide data in a different
|
||||
# layout, the TMA loads will read wrong addresses.
|
||||
#
|
||||
# However, for the _single_ expert GEMM (not grouped), CUTLASS handles
|
||||
# the TMA descriptor setup internally based on the stride/layout we pass.
|
||||
# If we pass our scales with a simple row-major stride, CUTLASS will
|
||||
# still try to read them using the interleaved layout — which is wrong.
|
||||
#
|
||||
# The fix: we need to either:
|
||||
# 1. Remap our data to match the CUTLASS layout (complex)
|
||||
# 2. Or use a different CUTLASS API that accepts custom strides
|
||||
#
|
||||
# For now, let's compute the CUTLASS layout size, allocate, and
|
||||
# do a simple remap. The remap is: for each (row, k_group) in the
|
||||
# CUTLASS layout, write the scale factor from our simple layout.
|
||||
|
||||
# For SFVecSize=16, the SfAtom has:
|
||||
# Shape<(32,4), (16,4)> with Stride<(16,4), (0,1)>
|
||||
#
|
||||
# This means for the K-major atom:
|
||||
# - Row index (first mode): 32 groups of 4, stride (16, 4)
|
||||
# So element (i, j) in (32, 4) maps to offset i*16 + j*4
|
||||
# - Col index (second mode): 16 groups of 4, stride (0, 1)
|
||||
# So element (k, l) in (16, 4) maps to offset k*0 + l*1 = l
|
||||
#
|
||||
# Combined: offset = (i*16 + j*4) + l * 128
|
||||
# (because the second mode stride (0, 1) means col dimension
|
||||
# doesn't contribute to the first mode and vice versa)
|
||||
#
|
||||
# Actually, for a 2D CuTe layout:
|
||||
# Shape<(32,4), (16,4)> = (128, 64)
|
||||
# Stride<(16,4), (0,1)> means:
|
||||
# First mode (row): stride 16, 4 -> 32*4=128 rows, row stride = (16,4)
|
||||
# Second mode (col): stride 0, 1 -> 16*4=64 cols, col stride = (0,1)
|
||||
#
|
||||
# For element at (r, c) where r = i*4+j (i in 0..31, j in 0..3)
|
||||
# c = k*4+l (k in 0..15, l in 0..3):
|
||||
# offset = i*16 + j*4 + l*128
|
||||
#
|
||||
# Wait, that's not right either. Let me just compute the layout properly.
|
||||
# The total number of SF elements should be M * (K // 16) for our data.
|
||||
|
||||
# For now, let's just return the scales as-is and see if CUTLASS can handle
|
||||
# a simple stride. The real fix would be to use CUTLASS's own layout helpers.
|
||||
return sf
|
||||
|
||||
Reference in New Issue
Block a user