fix: rewrite SF remap as forward mapping (source→dst)

- Iterate over source indices (MN * K_sf) instead of dst indices
- Use layout_sf forward mapping: layout_sf(make_coord(mn, k_sf*16))
- No more idx2crd reverse extraction or stride-0 ambiguity
- Cleaner, less error-prone, blog-compatible
This commit is contained in:
2026-05-15 20:51:30 +00:00
parent 30b6c89424
commit 63e67e1025

View File

@@ -124,39 +124,40 @@ __global__ void remap_sf_to_cutlass_kernel(
int MN, int K_sf, // Source dimensions (in SF groups)
bool col_major_src = false // true if source is (K_sf, MN) row-major
) {
int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = cute::cosize(layout_sf);
if (dst_idx >= total) return;
// Forward-mapping approach: iterate over source indices (mn, k_sf),
// compute the CUTLASS dst index via layout_sf forward mapping.
// k_sf * InputSFVectorSize converts from SF-group index to logical K element
// coordinate, which is what the layout expects.
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int total = MN * K_sf;
if (tid >= total) return;
auto coord = cute::idx2crd(dst_idx, layout_sf.shape(), layout_sf.stride());
auto flat = cute::flatten(coord);
int mn, k_sf_val, src_idx;
constexpr int R = cute::rank_v<decltype(flat)>;
int mn = 0;
int k_sf = 0;
if constexpr (R >= 6) {
// K-major SF atom: Shape<Shape<_32,_4>, Shape<SFVecSize,_4>>
// Stride<Stride<_16,_4>, Stride<_0,_1>>
// SFA tiled as make_shape(M,K,...), SFB as make_shape(N,K,...)
// First flattened group is M/N: ((32,4), tile_mn)
mn =
int(cute::get<0>(flat)) +
32 * int(cute::get<1>(flat)) +
128 * int(cute::get<2>(flat));
// Second flattened group is K: (SFVecSize, 4, tile_k)
// get<3> is the stride-0 k-within-SF-vector coordinate — ignore it.
k_sf =
int(cute::get<4>(flat)) +
4 * int(cute::get<5>(flat));
if (col_major_src) {
// source is row-major (K_sf, MN), e.g. SFB stored as (K_sf, N)
k_sf_val = tid / MN;
mn = tid % MN;
src_idx = tid;
} else {
// source is row-major (MN, K_sf), e.g. SFA stored as (M, K_sf)
mn = tid / K_sf;
k_sf_val = tid % K_sf;
src_idx = tid;
}
if (mn < MN && k_sf < K_sf) {
dst[dst_idx] = col_major_src
? src[k_sf * MN + mn] // SFB source: (K_sf, N)
: src[mn * K_sf + k_sf]; // SFA source: (M, K_sf)
// Use layout forward mapping: source (mn, k_sf*16) -> dst_idx
constexpr int LayoutRank = cute::rank_v<decltype(layout_sf.shape())>;
int dst_idx = 0;
if constexpr (LayoutRank == 2) {
dst_idx = layout_sf(cute::make_coord(mn, k_sf_val * InputSFVectorSize));
} else if constexpr (LayoutRank == 3) {
dst_idx = layout_sf(cute::make_coord(mn, k_sf_val * InputSFVectorSize, 0));
}
if (dst_idx >= 0 && dst_idx < cute::cosize(layout_sf)) {
dst[dst_idx] = src[src_idx];
}
}
@@ -197,9 +198,11 @@ int cutlass_nvfp4_gemm_run(
cudaMemsetAsync(sfb_cutlass.get(), 0, sfb_size * sizeof(ElementSF), stream);
int block = 256;
remap_sf_to_cutlass_kernel<<<(sfa_size + block - 1) / block, block, 0, stream>>>(
int sfa_src_total = M * K_sf;
int sfb_src_total = N * K_sf;
remap_sf_to_cutlass_kernel<<<(sfa_src_total + block - 1) / block, block, 0, stream>>>(
static_cast<const ElementSF*>(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf, false);
remap_sf_to_cutlass_kernel<<<(sfb_size + block - 1) / block, block, 0, stream>>>(
remap_sf_to_cutlass_kernel<<<(sfb_src_total + block - 1) / block, block, 0, stream>>>(
static_cast<const ElementSF*>(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf, true);
typename Gemm::Arguments arguments {
@@ -260,7 +263,8 @@ extern "C" int cutlass_nvfp4_prepack_sfb_run(
cudaMemsetAsync(static_cast<ElementSF*>(SFB_cutlass_ptr), 0, sfb_size * sizeof(ElementSF), stream);
int block = 256;
remap_sf_to_cutlass_kernel<<<(sfb_size + block - 1) / block, block, 0, stream>>>(
int sfb_src_total = N * K_sf;
remap_sf_to_cutlass_kernel<<<(sfb_src_total + block - 1) / block, block, 0, stream>>>(
static_cast<const ElementSF*>(SFB_ptr),
static_cast<ElementSF*>(SFB_cutlass_ptr),
layout_SFB,
@@ -303,7 +307,8 @@ extern "C" int cutlass_nvfp4_gemm_run_prepacked_sfb(
cudaMemsetAsync(sfa_cutlass.get(), 0, sfa_size * sizeof(ElementSF), stream);
int block = 256;
remap_sf_to_cutlass_kernel<<<(sfa_size + block - 1) / block, block, 0, stream>>>(
int sfa_src_total = M * K_sf;
remap_sf_to_cutlass_kernel<<<(sfa_src_total + block - 1) / block, block, 0, stream>>>(
static_cast<const ElementSF*>(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf, false);
typename Gemm::Arguments arguments {