fix: SF remap was using idx2crd+flatten which gives atom sub-indices, not logical (m,k)

The remap kernel iterated over CUTLASS linear indices and tried to
reverse-map with idx2crd + flatten. But flatten() on the nested CuTe
coordinate (from tile_to_shape(SfAtom{}, ...)) gives atom-level
sub-indices, not logical (m, k). This caused all K-groups > 0 in SFA
to map to m*K_sf+0, losing K-group information entirely.

Proof: setting SFA[0,0]=2.0 changed row 0, but SFA[0,3]=2.0 produced
zero change. Only K-group 0 was being read.

Fix: iterate over SOURCE indices (row-major m, k) and use the CuTe
layout forward: layout_sf(make_coord(m, k)) -> CUTLASS dst index.
This is the correct forward direction that CuTe handles natively.

Constant-scale test (all SF=1.0) gave cosine=1.0, confirming the FP4
data path is correct. The bug was purely in the SF remap.
@@ -104,36 +104,27 @@ __global__ void remap_sf_to_cutlass_kernel(
     LayoutSF layout_sf,  // CuTe layout for dst
     int MN, int K_sf     // Source dimensions
 ) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    int total = cute::size(layout_sf);
-    if (idx >= total) return;
-
-    // The CUTLASS layout maps linear index -> (m, k) coordinate pair
-    // We need to find which (m, k) this linear index corresponds to
-    // and then read from our simple row-major source.
+    // Iterate over SOURCE indices (row-major) and write to CUTLASS destination.
+    // The layout maps logical (m, k) -> CUTLASS linear index.
+    // This is the forward direction, which CuTe handles correctly.
     //
-    // CuTe layouts support crd(idx) which gives the coordinate for an index.
-    // The coordinate is in the logical space of the layout.
-    // For SFA: the layout maps to (M, K) where K is in SF groups
-    // For SFB: the layout maps to (N, K) where K is in SF groups
-    //
-    // The key: the layout was created by tile_to_shape(SfAtom{}, make_shape(MN, K), Step<_2,_1>{})
-    // So the coordinate (c0, c1) corresponds to row c0 and K-group c1 in the original tensor.
+    // Previous approach (iterate over CUTLASS idx, reverse-map with idx2crd+flatten)
+    // was broken: flatten() on the nested CuTe coordinate gives atom sub-indices,
+    // not logical (m, k). This caused K-group > 0 to always map to m*K_sf+0,
+    // losing all K-group information in the SFA.
+    int src_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int total = MN * K_sf;
+    if (src_idx >= total) return;
 
-    // idx2crd converts a linear index to a logical coordinate in the layout's space
-    auto coord = cute::idx2crd(idx, layout_sf.shape(), layout_sf.stride());
+    int m = src_idx / K_sf;
+    int k = src_idx % K_sf;
 
-    // The coordinate is a nested tuple. For a 2D layout it's (c0, c1)
-    // where c0 = row/M index and c1 = col/K-group index.
-    // Flatten the nested tuple to extract the two logical coordinates.
-    auto flat = cute::flatten(coord);
-    int m = cute::get<0>(flat);
-    int k = cute::get<1>(flat);
+    // Use the CuTe layout to find the destination index for this (m, k)
+    // layout_sf(m, k) returns the linear index in CUTLASS's expected layout
+    auto dst_idx = layout_sf(cute::make_coord(m, k));
 
-    if (m < MN && k < K_sf) {
-        dst[idx] = src[m * K_sf + k];
-    } else {
-        dst[idx] = cutlass::float_ue4m3_t(0);
+    if (dst_idx < cute::size(layout_sf)) {
+        dst[dst_idx] = src[src_idx];
     }
 }
 
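The failure mode described in the commit message can be illustrated with a
small host-only toy model. This is a sketch under stated assumptions, not the
real CUTLASS scale-factor layout: ToyLayout, TM, TK, idx2crd_flat,
remap_forward and remap_flatten_bug are hypothetical names, and the 4x2 tile
is an arbitrary stand-in for SfAtom. It only shows why taking the first two
entries of a flattened nested coordinate does not yield logical (m, k), and
why iterating over source indices with the forward map avoids the problem.
It does not reproduce the exact m*K_sf+0 symptom seen with the real layout.

```cpp
#include <array>
#include <cassert>
#include <vector>

// Toy tile extents (assumption; NOT the real SfAtom shape).
constexpr int TM = 4;  // rows per tile
constexpr int TK = 2;  // K-groups per tile

// Toy tiled layout over a logical (M, K) space. The nested coordinate is
// ((m_in, m_tile), (k_in, k_tile)); M and K are assumed multiples of TM, TK.
struct ToyLayout {
    int M;
    int K;

    int size() const { return M * K; }

    // Forward map: logical (m, k) -> destination linear index.
    int operator()(int m, int k) const {
        int m_in = m % TM, m_tile = m / TM;
        int k_in = k % TK, k_tile = k / TK;
        return k_in + TK * (m_in + TM * (k_tile + (K / TK) * m_tile));
    }

    // Reverse map: linear index -> flattened nested coordinate
    // {m_in, m_tile, k_in, k_tile}. Note entries 0 and 1 are NOT (m, k).
    std::array<int, 4> idx2crd_flat(int idx) const {
        int k_in = idx % TK;  idx /= TK;
        int m_in = idx % TM;  idx /= TM;
        int k_tile = idx % (K / TK);
        int m_tile = idx / (K / TK);
        return {m_in, m_tile, k_in, k_tile};
    }
};

// Fixed approach: walk the row-major source and scatter through the forward map.
std::vector<float> remap_forward(const std::vector<float>& src, const ToyLayout& layout) {
    assert(static_cast<int>(src.size()) == layout.M * layout.K);
    std::vector<float> dst(layout.size(), 0.0f);
    for (int src_idx = 0; src_idx < layout.M * layout.K; ++src_idx) {
        int m = src_idx / layout.K;
        int k = src_idx % layout.K;
        int dst_idx = layout(m, k);
        if (dst_idx < layout.size()) dst[dst_idx] = src[src_idx];
    }
    return dst;
}

// Broken approach: walk destination indices and treat the first two flattened
// sub-indices as if they were the logical (m, k).
std::vector<float> remap_flatten_bug(const std::vector<float>& src, const ToyLayout& layout) {
    assert(static_cast<int>(src.size()) == layout.M * layout.K);
    std::vector<float> dst(layout.size(), 0.0f);
    for (int idx = 0; idx < layout.size(); ++idx) {
        std::array<int, 4> flat = layout.idx2crd_flat(idx);
        int m = flat[0];  // really m_in, the row within the tile
        int k = flat[1];  // really m_tile, not a K-group at all
        dst[idx] = (m < layout.M && k < layout.K) ? src[m * layout.K + k] : 0.0f;
    }
    return dst;
}
```

With an 8x4 source in this toy layout, remap_forward places every src[m*K+k]
at layout(m, k), while remap_flatten_bug reads the wrong source element for
most destinations, which is the same class of error the diff above removes.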