diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu index 0cdc82c6..fb3e05cf 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu @@ -93,12 +93,17 @@ using StrideD = typename Gemm::GemmKernel::StrideD; using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; -///////////////////////////////////////////////////////////////////////////////////////////////// -// Scale factor remap kernel using CuTe layout operations -///////////////////////////////////////////////////////////////////////////////////////////////// - ///////////////////////////////////////////////////////////////////////////////////////////////// // Scale factor remap kernel: row-major source -> CUTLASS interleaved layout +// +// Iterates over source (m, k_group) indices and uses layout_sf(m, k_elem) to +// compute the CUTLASS destination offset. CuTe's layout operator() decomposes +// flat coordinates into nested sub-coordinates automatically — no idx2crd, +// no flatten, no rank inspection. +// +// K is in element-space for the layout: k_group * SFVecSize. +// Iterating groups (not every element) is efficient since all 16 elements +// within a group share one SF byte. ///////////////////////////////////////////////////////////////////////////////////////////////// template @@ -109,17 +114,6 @@ __global__ void remap_sf_to_cutlass_kernel( int MN, int K_sf, // Source dimensions (in SF groups) int SFVecSize // Elements per SF group (16 for NVFP4) ) { - // SOURCE-ITERATING: for each (m, k_group) in the row-major source, - // compute the CUTLASS destination offset using layout_sf(m, k_elem). - // - // CuTe's layout operator() accepts "natural" coordinates whose rank matches - // the top-level mode count (rank-2 here: M and K). It decomposes flat ints - // into nested sub-coordinates automatically. No idx2crd, no flatten, no rank inspection. - // - // The K coordinate is in ELEMENT-SPACE (not group-space): - // k_group=3 with SFVecSize=16 -> k_elem=48 in the layout. - // All 16 elements within a group share the same SF byte, so we iterate - // over groups (not elements) for efficiency — one write per SF value. int src_idx = blockIdx.x * blockDim.x + threadIdx.x; int total = MN * K_sf; if (src_idx >= total) return; @@ -128,14 +122,8 @@ __global__ void remap_sf_to_cutlass_kernel( int k_group = src_idx % K_sf; int k_elem = k_group * SFVecSize; - // layout_sf(m, k_elem) -> CUTLASS linear index - // CuTe handles the nested shape decomposition internally. dst[layout_sf(m, k_elem)] = src[src_idx]; } - if (m < MN && k_sf_out < K_sf) { - dst[dst_idx] = src[m * K_sf + k_sf_out]; - } -} ///////////////////////////////////////////////////////////////////////////////////////////////// // C API @@ -169,16 +157,12 @@ int cutlass_nvfp4_gemm_run( int K_sf = K / InputSFVectorSize; // Allocate CUTLASS-layout buffers, zero-init, and remap scales - // Zero-init is critical: CUTLASS pads to tile boundaries (128x64), - // so dest is larger than M*K_sf. Unmapped slots must be zero to avoid - // garbage values in the GEMM. cutlass::device_memory::allocation sfa_cutlass(sfa_size); cutlass::device_memory::allocation sfb_cutlass(sfb_size); cudaMemsetAsync(sfa_cutlass.get(), 0, sfa_size * sizeof(ElementSF), stream); cudaMemsetAsync(sfb_cutlass.get(), 0, sfb_size * sizeof(ElementSF), stream); int block = 256; - // Grid size based on SOURCE elements (MN * K_sf) since kernel iterates source int sfa_src = M * K_sf; int sfb_src = N * K_sf; remap_sf_to_cutlass_kernel<<<(sfa_src + block - 1) / block, block, 0, stream>>>(