fix: rewrite SF remap kernel — source-iterating with layout_sf(m, k_elem)

Ripped out idx2crd + flatten + get<> approach entirely. New kernel
iterates over source indices (m, k_group) and uses layout_sf(m, k_elem)
to compute the CUTLASS destination offset. CuTe handles nested shape
decomposition internally — no rank inspection needed.

K coordinate is in element-space (k_group * SFVecSize) as the layout
expects. Iterates over groups (not every element) since all 16 elements
within a group share one SF byte — avoids 16x redundant writes.

Grid size based on source count (MN * K_sf), not dest buffer size.
This commit is contained in:
2026-05-14 15:28:44 +00:00
parent fb390b24e2
commit 196ee37fdb

View File

@@ -98,59 +98,40 @@ using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
// Scale factor remap kernel using CuTe layout operations
// Scale factor remap kernel: row-major source -> CUTLASS interleaved layout
/////////////////////////////////////////////////////////////////////////////////////////////////
template<typename LayoutSF>
__global__ void remap_sf_to_cutlass_kernel(
const cutlass::float_ue4m3_t* __restrict__ src,
cutlass::float_ue4m3_t* __restrict__ dst,
LayoutSF layout_sf,
int MN, int K_sf
const cutlass::float_ue4m3_t* __restrict__ src, // (MN, K_sf) row-major
cutlass::float_ue4m3_t* __restrict__ dst, // CUTLASS interleaved layout (zero-initialized)
LayoutSF layout_sf, // CuTe layout for dst
int MN, int K_sf, // Source dimensions (in SF groups)
int SFVecSize // Elements per SF group (16 for NVFP4)
) {
int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = cute::size(layout_sf);
if (dst_idx >= total) return;
// SOURCE-ITERATING: for each (m, k_group) in the row-major source,
// compute the CUTLASS destination offset using layout_sf(m, k_elem).
//
// CuTe's layout operator() accepts "natural" coordinates whose rank matches
// the top-level mode count (rank-2 here: M and K). It decomposes flat ints
// into nested sub-coordinates automatically. No idx2crd, no flatten, no rank inspection.
//
// The K coordinate is in ELEMENT-SPACE (not group-space):
// k_group=3 with SFVecSize=16 -> k_elem=48 in the layout.
// All 16 elements within a group share the same SF byte, so we iterate
// over groups (not elements) for efficiency — one write per SF value.
int src_idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = MN * K_sf;
if (src_idx >= total) return;
auto coord = cute::idx2crd(dst_idx, layout_sf.shape(), layout_sf.stride());
auto flat = cute::flatten(coord);
constexpr int flat_rank = cute::rank_v<decltype(flat)>;
// Debug: print flat_rank and first few coordinates (only for idx 0)
if (dst_idx == 0) {
printf("[remap] flat_rank=%d, MN=%d, K_sf=%d, total=%d\n", flat_rank, MN, K_sf, total);
printf("[remap] layout shape: ");
cute::print(layout_sf.shape());
printf("\n");
}
int m, k_sf_out;
if constexpr (flat_rank == 6) {
// Full nested: (inner_M, sub_M, tile_M, inner_K, sub_K, tile_K)
int inner_m = cute::get<0>(flat);
int sub_m = cute::get<1>(flat);
int tile_m = cute::get<2>(flat);
// get<3> = inner_k (within one SF group — same byte)
int sub_k = cute::get<4>(flat);
int tile_k = cute::get<5>(flat);
m = tile_m * 128 + inner_m * 4 + sub_m;
k_sf_out = tile_k * 4 + sub_k;
} else if constexpr (flat_rank == 4) {
int inner_m = cute::get<0>(flat);
int sub_m = cute::get<1>(flat);
int sub_k = cute::get<2>(flat);
int tile_k = cute::get<3>(flat);
m = inner_m * 4 + sub_m;
k_sf_out = tile_k * 4 + sub_k;
} else {
m = cute::get<0>(flat);
k_sf_out = cute::get<1>(flat);
}
int m = src_idx / K_sf;
int k_group = src_idx % K_sf;
int k_elem = k_group * SFVecSize;
// layout_sf(m, k_elem) -> CUTLASS linear index
// CuTe handles the nested shape decomposition internally.
dst[layout_sf(m, k_elem)] = src[src_idx];
}
if (m < MN && k_sf_out < K_sf) {
dst[dst_idx] = src[m * K_sf + k_sf_out];
}
@@ -197,11 +178,13 @@ int cutlass_nvfp4_gemm_run(
cudaMemsetAsync(sfb_cutlass.get(), 0, sfb_size * sizeof(ElementSF), stream);
int block = 256;
// Grid size based on DEST (CUTLASS buffer) size since kernel iterates dest
remap_sf_to_cutlass_kernel<<<(sfa_size + block - 1) / block, block, 0, stream>>>(
static_cast<const ElementSF*>(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf);
remap_sf_to_cutlass_kernel<<<(sfb_size + block - 1) / block, block, 0, stream>>>(
static_cast<const ElementSF*>(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf);
// Grid size based on SOURCE elements (MN * K_sf) since kernel iterates source
int sfa_src = M * K_sf;
int sfb_src = N * K_sf;
remap_sf_to_cutlass_kernel<<<(sfa_src + block - 1) / block, block, 0, stream>>>(
static_cast<const ElementSF*>(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf, InputSFVectorSize);
remap_sf_to_cutlass_kernel<<<(sfb_src + block - 1) / block, block, 0, stream>>>(
static_cast<const ElementSF*>(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf, InputSFVectorSize);
typename Gemm::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,