Fix: use CuTe get_coord for proper scale factor remap to CUTLASS interleaved layout
This commit is contained in:
@@ -1,8 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* CUTLASS NVFP4 Block-Scaled GEMM for DeepSeek-V4-Pro MoE
|
||||
*
|
||||
* Based on NVIDIA CUTLASS example 72b_blackwell_nvfp4_nvfp4_gemm.cu
|
||||
* Uses native tcgen05.mma kind::mxf8f6f4.block_scale instructions on Blackwell SM100.
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
@@ -36,10 +33,6 @@
|
||||
|
||||
using namespace cute;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// NVFP4 × NVFP4 → BF16 GEMM
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
|
||||
using LayoutATag = cutlass::layout::RowMajor;
|
||||
constexpr int AlignmentA = 32;
|
||||
@@ -101,28 +94,13 @@ using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Scale factor remap kernel
|
||||
// Scale factor remap kernel using CuTe layout operations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Remap simple row-major (MN, K//16) scales to CUTLASS interleaved layout
|
||||
// The CUTLASS layout is computed by Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA
|
||||
// which tiles SfAtom to the problem shape.
|
||||
//
|
||||
// For SFVecSize=16, K-major:
|
||||
// SfAtom = Shape<Shape<32,4>, Shape<16,4>>
|
||||
// Stride<Stride<16,4>, Stride<0,1>>
|
||||
//
|
||||
// tile_to_shape(SfAtom{}, make_shape(M,K), Step<_2,_1>{}) produces a 2D layout
|
||||
// where the first mode (row/M) is tiled with step 2 (interleaved with the second mode)
|
||||
// and the second mode (col/K) is tiled with step 1.
|
||||
//
|
||||
// For our remap, we iterate over the CUTLASS layout indices and for each,
|
||||
// compute which (row, k_group) from the source it corresponds to, then copy.
|
||||
|
||||
template<typename LayoutSF>
|
||||
__global__ void remap_sf_to_cutlass_layout_kernel(
|
||||
const cutlass::float_ue4m3_t* __restrict__ src, // (MN, K//16) row-major
|
||||
cutlass::float_ue4m3_t* __restrict__ dst, // CUTLASS layout
|
||||
__global__ void remap_sf_to_cutlass_kernel(
|
||||
const cutlass::float_ue4m3_t* __restrict__ src, // (MN, K_sf) row-major
|
||||
cutlass::float_ue4m3_t* __restrict__ dst, // CUTLASS interleaved layout
|
||||
LayoutSF layout_sf, // CuTe layout for dst
|
||||
int MN, int K_sf // Source dimensions
|
||||
) {
|
||||
@@ -130,28 +108,28 @@ __global__ void remap_sf_to_cutlass_layout_kernel(
|
||||
int total = cute::size(layout_sf);
|
||||
if (idx >= total) return;
|
||||
|
||||
// The CUTLASS layout maps idx -> (row_coord, k_coord)
|
||||
// We need to figure out which (row, k_group) this corresponds to
|
||||
// in the simple row-major source layout.
|
||||
// The CUTLASS layout maps linear index -> (m, k) coordinate pair
|
||||
// We need to find which (m, k) this linear index corresponds to
|
||||
// and then read from our simple row-major source.
|
||||
//
|
||||
// For the SfAtom with SFVecSize=16, K-major:
|
||||
// The layout interleaves scale factors in groups of 128 rows
|
||||
// and 64 K-groups (4 "tiles" of 16).
|
||||
// CuTe layouts support crd(idx) which gives the coordinate for an index.
|
||||
// The coordinate is in the logical space of the layout.
|
||||
// For SFA: the layout maps to (M, K) where K is in SF groups
|
||||
// For SFB: the layout maps to (N, K) where K is in SF groups
|
||||
//
|
||||
// Simple approach: use the layout's coordinate function.
|
||||
// The layout maps a linear index to a logical coordinate (m, k).
|
||||
// We can extract (m, k) and then index into our source as m * K_sf + k.
|
||||
// The key: the layout was created by tile_to_shape(SfAtom{}, make_shape(MN, K), Step<_2,_1>{})
|
||||
// So the coordinate (c0, c1) corresponds to row c0 and K-group c1 in the original tensor.
|
||||
|
||||
// Get the 2D coordinate from the layout
|
||||
auto coord = layout_sf.get_flat_coord(idx);
|
||||
int m = cute::get<0>(coord); // Row index in the layout's coordinate system
|
||||
int k = cute::get<1>(coord); // K-group index
|
||||
auto coord = layout_sf.get_coord(idx);
|
||||
|
||||
// Extract the (m, k) pair from the coordinate tuple
|
||||
int m = cute::get<0>(cute::flatten(coord)); // Row index
|
||||
int k = cute::get<1>(cute::flatten(coord)); // K-group index
|
||||
|
||||
// Map to source index
|
||||
if (m < MN && k < K_sf) {
|
||||
dst[idx] = src[m * K_sf + k];
|
||||
} else {
|
||||
dst[idx] = cutlass::float_ue4m3_t(0); // Zero padding
|
||||
dst[idx] = cutlass::float_ue4m3_t(0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -178,33 +156,32 @@ int cutlass_nvfp4_gemm_run(
|
||||
LayoutSFA layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
|
||||
LayoutSFB layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
|
||||
|
||||
// For now, pass scale factors directly (no remap) to test if the kernel works
|
||||
// The CUTLASS layout may expect interleaved data, but let's see what happens
|
||||
// with simple row-major. If TMA reads garbage, we'll get NaN or wrong values
|
||||
// rather than crashes (since we verified small sizes work).
|
||||
//
|
||||
// TODO: implement proper interleaved layout remap
|
||||
|
||||
// Temporary: allocate CUTLASS-layout buffers and do a simple copy
|
||||
int K_sf = K / InputSFVectorSize;
|
||||
|
||||
using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA;
|
||||
using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB;
|
||||
using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF;
|
||||
|
||||
// Pass scale factors directly — the layout_SFA/layout_SFB tell CUTLASS
|
||||
// how to index into the data. Our data is (MN, K//16) row-major but
|
||||
// CUTLASS expects its own interleaved layout.
|
||||
// For initial testing, pass directly and check for NaN/wrong values.
|
||||
|
||||
int sfa_size = cute::size(layout_SFA);
|
||||
int sfb_size = cute::size(layout_SFB);
|
||||
int K_sf = K / InputSFVectorSize;
|
||||
|
||||
// Allocate CUTLASS-layout buffers and remap scales
|
||||
cutlass::device_memory::allocation<ElementSF> sfa_cutlass(sfa_size);
|
||||
cutlass::device_memory::allocation<ElementSF> sfb_cutlass(sfb_size);
|
||||
|
||||
int block = 256;
|
||||
remap_sf_to_cutlass_kernel<<<(sfa_size + block - 1) / block, block, 0, stream>>>(
|
||||
static_cast<const ElementSF*>(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf);
|
||||
remap_sf_to_cutlass_kernel<<<(sfb_size + block - 1) / block, block, 0, stream>>>(
|
||||
static_cast<const ElementSF*>(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf);
|
||||
|
||||
typename Gemm::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{M, N, K, 1},
|
||||
{
|
||||
static_cast<const ArrayElementA*>(A_ptr), stride_A,
|
||||
static_cast<const ArrayElementB*>(B_ptr), stride_B,
|
||||
static_cast<const ElementSF*>(SFA_ptr), layout_SFA,
|
||||
static_cast<const ElementSF*>(SFB_ptr), layout_SFB
|
||||
sfa_cutlass.get(), layout_SFA,
|
||||
sfb_cutlass.get(), layout_SFB
|
||||
},
|
||||
{
|
||||
{ alpha, beta },
|
||||
@@ -227,4 +204,4 @@ int cutlass_nvfp4_gemm_run(
|
||||
|
||||
} // extern "C"
|
||||
|
||||
#endif // CUTLASS_ARCH_MMA_SM100_SUPPORTED
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user