fix: SF remap was using idx2crd+flatten which gives atom sub-indices, not logical (m,k)

The remap kernel iterated over CUTLASS linear indices and tried to
reverse-map with idx2crd + flatten. But flatten() on the nested CuTe
coordinate (from tile_to_shape(SfAtom{}, ...)) gives atom-level
sub-indices, not logical (m, k). This caused all K-groups > 0 in SFA
to map to m*K_sf+0, losing K-group information entirely.

Evidence: setting SFA[0,0]=2.0 changed row 0, but setting SFA[0,3]=2.0
produced zero change. Only K-group 0 was being read.

Fix: iterate over SOURCE indices (row-major m, k) and use the CuTe
layout forward: layout_sf(make_coord(m, k)) -> CUTLASS dst index.
This is the correct forward direction that CuTe handles natively.
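
The difference between the two directions can be reproduced on the host without CuTe. The sketch below is only an illustration: it uses a made-up nested layout with shape ((MA, MR), (KA, KR)) and hand-picked strides, not the real SfAtom, and every name in it (forward_index, flat_coord, remap_buggy, remap_fixed) is hypothetical. It shows that the first two entries of a flattened nested coordinate are sub-indices of the M mode, so reading them as (m, k) drops the K-group, while the forward mapping places every source element correctly.

```cpp
#include <array>
#include <cassert>
#include <vector>

// Hypothetical toy layout with nested shape ((MA, MR), (KA, KR)).
// Extents and strides are made up for illustration; this is NOT the
// real CUTLASS SfAtom.
constexpr int MA = 4, MR = 2;  // rows inside one atom, atoms along M
constexpr int KA = 2, KR = 2;  // K-groups inside one atom, atoms along K
constexpr int M = MA * MR;     // logical rows
constexpr int K = KA * KR;     // logical K-groups

// Stride of each leaf mode in the destination buffer.
constexpr int S_KA = 1, S_MA = 2, S_KR = 8, S_MR = 16;

// Forward direction: logical (m, k) -> destination linear index.
int forward_index(int m, int k) {
    int ma = m % MA, mr = m / MA;
    int ka = k % KA, kr = k / KA;
    return ma * S_MA + mr * S_MR + ka * S_KA + kr * S_KR;
}

// Reverse direction: destination index -> flattened leaf coordinate
// in the order (ma, mr, ka, kr).
std::array<int, 4> flat_coord(int idx) {
    return {(idx / S_MA) % MA, (idx / S_MR) % MR,
            (idx / S_KA) % KA, (idx / S_KR) % KR};
}

// Old approach: walk destination indices and treat the first two
// flattened entries as (m, k). They are really (ma, mr).
std::vector<int> remap_buggy(const std::vector<int>& src) {
    assert(src.size() == static_cast<std::size_t>(M * K));
    std::vector<int> dst(M * K, -1);
    for (int idx = 0; idx < M * K; ++idx) {
        std::array<int, 4> flat = flat_coord(idx);
        int m = flat[0];  // actually ma, the row inside the atom
        int k = flat[1];  // actually mr, the atom index along M
        dst[idx] = src[m * K + k];
    }
    return dst;
}

// New approach: walk row-major source indices and use the forward map.
std::vector<int> remap_fixed(const std::vector<int>& src) {
    assert(src.size() == static_cast<std::size_t>(M * K));
    std::vector<int> dst(M * K, -1);
    for (int src_idx = 0; src_idx < M * K; ++src_idx) {
        int m = src_idx / K;
        int k = src_idx % K;
        dst[forward_index(m, k)] = src[src_idx];
    }
    return dst;
}
```

With src[i] = i, logical element (0, 3) belongs at destination index 9 in this toy layout. remap_fixed stores 3 there, while remap_buggy stores src[0], which mirrors the symptom of K-groups > 0 never being read.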

Constant-scale test (all SF=1.0) gave cosine=1.0, confirming the FP4
data path is correct. The bug was purely in the SF remap.
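
The commit does not show how the cosine figure was computed. As a minimal sketch, assuming the check compares the kernel output against a reference result as two flat float vectors, a helper along these lines would do; the name cosine_similarity and the zero-norm convention are assumptions, not taken from the repository.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: cosine similarity of two equally sized vectors.
// Returns 0.0 when either vector has zero norm (an arbitrary convention).
double cosine_similarity(const std::vector<float>& a,
                         const std::vector<float>& b) {
    assert(a.size() == b.size());
    double dot = 0.0, norm_a = 0.0, norm_b = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) {
        dot    += static_cast<double>(a[i]) * b[i];
        norm_a += static_cast<double>(a[i]) * a[i];
        norm_b += static_cast<double>(b[i]) * b[i];
    }
    if (norm_a == 0.0 || norm_b == 0.0) return 0.0;
    return dot / (std::sqrt(norm_a) * std::sqrt(norm_b));
}
```

A value near 1.0 only says the two results point in the same direction, so a uniform scaling error would go unnoticed; that fits its use here as a check on the data path rather than on absolute scale.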
2026-05-14 14:51:02 +00:00
parent cf796e37cf
commit 5968ebad9f

@@ -104,36 +104,27 @@ __global__ void remap_sf_to_cutlass_kernel(
     LayoutSF layout_sf,   // CuTe layout for dst
     int MN, int K_sf      // Source dimensions
 ) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    int total = cute::size(layout_sf);
-    if (idx >= total) return;
-    // The CUTLASS layout maps linear index -> (m, k) coordinate pair
-    // We need to find which (m, k) this linear index corresponds to
-    // and then read from our simple row-major source.
+    // Iterate over SOURCE indices (row-major) and write to CUTLASS destination.
+    // The layout maps logical (m, k) -> CUTLASS linear index.
+    // This is the forward direction, which CuTe handles correctly.
     //
-    // CuTe layouts support crd(idx) which gives the coordinate for an index.
-    // The coordinate is in the logical space of the layout.
-    // For SFA: the layout maps to (M, K) where K is in SF groups
-    // For SFB: the layout maps to (N, K) where K is in SF groups
-    //
-    // The key: the layout was created by tile_to_shape(SfAtom{}, make_shape(MN, K), Step<_2,_1>{})
-    // So the coordinate (c0, c1) corresponds to row c0 and K-group c1 in the original tensor.
+    // Previous approach (iterate over CUTLASS idx, reverse-map with idx2crd+flatten)
+    // was broken: flatten() on the nested CuTe coordinate gives atom sub-indices,
+    // not logical (m, k). This caused K-group > 0 to always map to m*K_sf+0,
+    // losing all K-group information in the SFA.
+    int src_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int total = MN * K_sf;
+    if (src_idx >= total) return;
-    // idx2crd converts a linear index to a logical coordinate in the layout's space
-    auto coord = cute::idx2crd(idx, layout_sf.shape(), layout_sf.stride());
+    int m = src_idx / K_sf;
+    int k = src_idx % K_sf;
-    // The coordinate is a nested tuple. For a 2D layout it's (c0, c1)
-    // where c0 = row/M index and c1 = col/K-group index.
-    // Flatten the nested tuple to extract the two logical coordinates.
-    auto flat = cute::flatten(coord);
-    int m = cute::get<0>(flat);
-    int k = cute::get<1>(flat);
+    // Use the CuTe layout to find the destination index for this (m, k)
+    // layout_sf(m, k) returns the linear index in CUTLASS's expected layout
+    auto dst_idx = layout_sf(cute::make_coord(m, k));
-    if (m < MN && k < K_sf) {
-        dst[idx] = src[m * K_sf + k];
-    } else {
-        dst[idx] = cutlass::float_ue4m3_t(0);
+    if (dst_idx < cute::size(layout_sf)) {
+        dst[dst_idx] = src[src_idx];
     }
 }