From 494d30b6ab525a2e82e4526a2985bfdcda4e3a22 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 14 May 2026 10:48:58 +0000 Subject: [PATCH] Fix: use CuTe get_coord for proper scale factor remap to CUTLASS interleaved layout --- .../cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu | 95 +++++++------------ 1 file changed, 36 insertions(+), 59 deletions(-) diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu index 6b4eacee..cc8ec61c 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu @@ -1,8 +1,5 @@ /*************************************************************************************************** * CUTLASS NVFP4 Block-Scaled GEMM for DeepSeek-V4-Pro MoE - * - * Based on NVIDIA CUTLASS example 72b_blackwell_nvfp4_nvfp4_gemm.cu - * Uses native tcgen05.mma kind::mxf8f6f4.block_scale instructions on Blackwell SM100. **************************************************************************************************/ #pragma once @@ -36,10 +33,6 @@ using namespace cute; -///////////////////////////////////////////////////////////////////////////////////////////////// -// NVFP4 × NVFP4 → BF16 GEMM -///////////////////////////////////////////////////////////////////////////////////////////////// - using ElementA = cutlass::nv_float4_t; using LayoutATag = cutlass::layout::RowMajor; constexpr int AlignmentA = 32; @@ -101,28 +94,13 @@ using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; ///////////////////////////////////////////////////////////////////////////////////////////////// -// Scale factor remap kernel +// Scale factor remap kernel using CuTe layout operations ///////////////////////////////////////////////////////////////////////////////////////////////// -// Remap simple row-major (MN, K//16) scales to CUTLASS interleaved layout -// The CUTLASS layout is computed by Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA -// which tiles SfAtom to the problem shape. -// -// For SFVecSize=16, K-major: -// SfAtom = Shape, Shape<16,4>> -// Stride, Stride<0,1>> -// -// tile_to_shape(SfAtom{}, make_shape(M,K), Step<_2,_1>{}) produces a 2D layout -// where the first mode (row/M) is tiled with step 2 (interleaved with the second mode) -// and the second mode (col/K) is tiled with step 1. -// -// For our remap, we iterate over the CUTLASS layout indices and for each, -// compute which (row, k_group) from the source it corresponds to, then copy. - template -__global__ void remap_sf_to_cutlass_layout_kernel( - const cutlass::float_ue4m3_t* __restrict__ src, // (MN, K//16) row-major - cutlass::float_ue4m3_t* __restrict__ dst, // CUTLASS layout +__global__ void remap_sf_to_cutlass_kernel( + const cutlass::float_ue4m3_t* __restrict__ src, // (MN, K_sf) row-major + cutlass::float_ue4m3_t* __restrict__ dst, // CUTLASS interleaved layout LayoutSF layout_sf, // CuTe layout for dst int MN, int K_sf // Source dimensions ) { @@ -130,28 +108,28 @@ __global__ void remap_sf_to_cutlass_layout_kernel( int total = cute::size(layout_sf); if (idx >= total) return; - // The CUTLASS layout maps idx -> (row_coord, k_coord) - // We need to figure out which (row, k_group) this corresponds to - // in the simple row-major source layout. + // The CUTLASS layout maps linear index -> (m, k) coordinate pair + // We need to find which (m, k) this linear index corresponds to + // and then read from our simple row-major source. // - // For the SfAtom with SFVecSize=16, K-major: - // The layout interleaves scale factors in groups of 128 rows - // and 64 K-groups (4 "tiles" of 16). + // CuTe layouts support crd(idx) which gives the coordinate for an index. + // The coordinate is in the logical space of the layout. + // For SFA: the layout maps to (M, K) where K is in SF groups + // For SFB: the layout maps to (N, K) where K is in SF groups // - // Simple approach: use the layout's coordinate function. - // The layout maps a linear index to a logical coordinate (m, k). - // We can extract (m, k) and then index into our source as m * K_sf + k. + // The key: the layout was created by tile_to_shape(SfAtom{}, make_shape(MN, K), Step<_2,_1>{}) + // So the coordinate (c0, c1) corresponds to row c0 and K-group c1 in the original tensor. - // Get the 2D coordinate from the layout - auto coord = layout_sf.get_flat_coord(idx); - int m = cute::get<0>(coord); // Row index in the layout's coordinate system - int k = cute::get<1>(coord); // K-group index + auto coord = layout_sf.get_coord(idx); + + // Extract the (m, k) pair from the coordinate tuple + int m = cute::get<0>(cute::flatten(coord)); // Row index + int k = cute::get<1>(cute::flatten(coord)); // K-group index - // Map to source index if (m < MN && k < K_sf) { dst[idx] = src[m * K_sf + k]; } else { - dst[idx] = cutlass::float_ue4m3_t(0); // Zero padding + dst[idx] = cutlass::float_ue4m3_t(0); } } @@ -178,33 +156,32 @@ int cutlass_nvfp4_gemm_run( LayoutSFA layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); LayoutSFB layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); - // For now, pass scale factors directly (no remap) to test if the kernel works - // The CUTLASS layout may expect interleaved data, but let's see what happens - // with simple row-major. If TMA reads garbage, we'll get NaN or wrong values - // rather than crashes (since we verified small sizes work). - // - // TODO: implement proper interleaved layout remap - - // Temporary: allocate CUTLASS-layout buffers and do a simple copy - int K_sf = K / InputSFVectorSize; - using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF; - // Pass scale factors directly — the layout_SFA/layout_SFB tell CUTLASS - // how to index into the data. Our data is (MN, K//16) row-major but - // CUTLASS expects its own interleaved layout. - // For initial testing, pass directly and check for NaN/wrong values. - + int sfa_size = cute::size(layout_SFA); + int sfb_size = cute::size(layout_SFB); + int K_sf = K / InputSFVectorSize; + + // Allocate CUTLASS-layout buffers and remap scales + cutlass::device_memory::allocation sfa_cutlass(sfa_size); + cutlass::device_memory::allocation sfb_cutlass(sfb_size); + + int block = 256; + remap_sf_to_cutlass_kernel<<<(sfa_size + block - 1) / block, block, 0, stream>>>( + static_cast(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf); + remap_sf_to_cutlass_kernel<<<(sfb_size + block - 1) / block, block, 0, stream>>>( + static_cast(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf); + typename Gemm::Arguments arguments { cutlass::gemm::GemmUniversalMode::kGemm, {M, N, K, 1}, { static_cast(A_ptr), stride_A, static_cast(B_ptr), stride_B, - static_cast(SFA_ptr), layout_SFA, - static_cast(SFB_ptr), layout_SFB + sfa_cutlass.get(), layout_SFA, + sfb_cutlass.get(), layout_SFB }, { { alpha, beta }, @@ -227,4 +204,4 @@ int cutlass_nvfp4_gemm_run( } // extern "C" -#endif // CUTLASS_ARCH_MMA_SM100_SUPPORTED +#endif