diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu index 2a291ba7..6b4eacee 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu @@ -178,32 +178,33 @@ int cutlass_nvfp4_gemm_run( LayoutSFA layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); LayoutSFB layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + // For now, pass scale factors directly (no remap) to test if the kernel works + // The CUTLASS layout may expect interleaved data, but let's see what happens + // with simple row-major. If TMA reads garbage, we'll get NaN or wrong values + // rather than crashes (since we verified small sizes work). + // + // TODO: implement proper interleaved layout remap + + // Temporary: allocate CUTLASS-layout buffers and do a simple copy + int K_sf = K / InputSFVectorSize; + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF; - int sfa_size = cute::size(layout_SFA); - int sfb_size = cute::size(layout_SFB); - int K_sf = K / InputSFVectorSize; - - cutlass::device_memory::allocation sfa_cutlass(sfa_size); - cutlass::device_memory::allocation sfb_cutlass(sfb_size); - - // Remap scales from simple row-major to CUTLASS interleaved layout - int block = 256; - remap_sf_to_cutlass_layout_kernel<<<(sfa_size + block - 1) / block, block, 0, stream>>>( - static_cast(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf); - remap_sf_to_cutlass_layout_kernel<<<(sfb_size + block - 1) / block, block, 0, stream>>>( - static_cast(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf); - + // Pass scale factors directly — the layout_SFA/layout_SFB tell CUTLASS + // how to index into the data. Our data is (MN, K//16) row-major but + // CUTLASS expects its own interleaved layout. + // For initial testing, pass directly and check for NaN/wrong values. + typename Gemm::Arguments arguments { cutlass::gemm::GemmUniversalMode::kGemm, {M, N, K, 1}, { static_cast(A_ptr), stride_A, static_cast(B_ptr), stride_B, - sfa_cutlass.get(), layout_SFA, - sfb_cutlass.get(), layout_SFB + static_cast(SFA_ptr), layout_SFA, + static_cast(SFB_ptr), layout_SFB }, { { alpha, beta },