fix: SF remap — iterate dest indices, extract logical (m, k_sf) from nested coord
The forward-map approach (src -> layout_sf(m, k)) failed because CuTe's layout operator requires coordinates matching the nested shape rank, and passing flat (int, int) to a ((32,4),K) shape triggers Mismatched Ranks. New approach: iterate over CUTLASS dest indices, use idx2crd to get the hierarchical coordinate, flatten it, then extract logical (m, k_sf) by interpreting the flattened sub-coordinates correctly: flat[0..2] = (inner_M, sub_M, tile_M) -> m = tile_M*128 + inner_M*4 + sub_M flat[3..5] = (inner_K, sub_K, tile_K) -> k_sf = tile_K*4 + sub_K (inner_K is within one SF group — same byte, so ignored for k_sf) Previous bug: get<0> and get<1> of flatten gave (inner_M, sub_M) — both M sub-indices. K information was never extracted, so only k_group=0 worked. Dest buffer is zero-initialized so padding slots (where m >= MN or k_sf >= K_sf) stay zero.
This commit is contained in:
@@ -97,47 +97,57 @@ using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
|
||||
// Scale factor remap kernel using CuTe layout operations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Scale factor remap kernel using CuTe layout operations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename LayoutSF>
|
||||
__global__ void remap_sf_to_cutlass_kernel(
|
||||
const cutlass::float_ue4m3_t* __restrict__ src, // (MN, K_sf) row-major
|
||||
cutlass::float_ue4m3_t* __restrict__ dst, // CUTLASS interleaved layout
|
||||
LayoutSF layout_sf, // CuTe layout for dst
|
||||
int MN, int K_sf, // Source dimensions (in SF groups)
|
||||
int SFVecSize // Elements per SF group (16 for NVFP4)
|
||||
const cutlass::float_ue4m3_t* __restrict__ src,
|
||||
cutlass::float_ue4m3_t* __restrict__ dst,
|
||||
LayoutSF layout_sf,
|
||||
int MN, int K_sf
|
||||
) {
|
||||
// Iterate over SOURCE indices (row-major) and write to CUTLASS destination.
|
||||
// The layout maps logical (m, k_elements) -> CUTLASS linear index.
|
||||
// This is the forward direction, which CuTe handles correctly.
|
||||
//
|
||||
// IMPORTANT: The CuTe layout uses ELEMENT-SPACE for K, not group-space.
|
||||
// So k_group=3 with SFVecSize=16 maps to element k=3*16=48 in the layout.
|
||||
//
|
||||
// Previous approach (iterate over CUTLASS idx, reverse-map with idx2crd+flatten)
|
||||
// was broken: flatten() on the nested CuTe coordinate gives atom sub-indices,
|
||||
// not logical (m, k). This caused all K-groups > 0 in SFA to map to m*K_sf+0,
|
||||
// losing K-group information entirely.
|
||||
int src_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int total = MN * K_sf;
|
||||
if (src_idx >= total) return;
|
||||
|
||||
int m = src_idx / K_sf;
|
||||
int k_sf = src_idx % K_sf;
|
||||
int k_elem = k_sf * SFVecSize; // Convert to element-space for CuTe layout
|
||||
|
||||
// Use the CuTe layout to find the destination index for this (m, k_elem)
|
||||
// The layout has nested shape like ((32, 4, tiles_m), (16, 4, tiles_k)) but
|
||||
// we can pass a flat (m, k_elem) coordinate and CuTe will project correctly.
|
||||
// Using layout_sf[m, k_elem] (square brackets) avoids the rank-matching issue
|
||||
// that make_coord() has with nested shapes.
|
||||
int dst_idx = layout_sf(m, k_elem);
|
||||
int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int total = cute::size(layout_sf);
|
||||
if (dst_idx >= total) return;
|
||||
|
||||
auto coord = cute::idx2crd(dst_idx, layout_sf.shape(), layout_sf.stride());
|
||||
auto flat = cute::flatten(coord);
|
||||
|
||||
constexpr int flat_rank = cute::rank_v<decltype(flat)>;
|
||||
|
||||
int m, k_sf_out;
|
||||
|
||||
if constexpr (flat_rank == 6) {
|
||||
// Full nested: (inner_M, sub_M, tile_M, inner_K, sub_K, tile_K)
|
||||
int inner_m = cute::get<0>(flat);
|
||||
int sub_m = cute::get<1>(flat);
|
||||
int tile_m = cute::get<2>(flat);
|
||||
// get<3> = inner_k (within one SF group — same byte)
|
||||
int sub_k = cute::get<4>(flat);
|
||||
int tile_k = cute::get<5>(flat);
|
||||
|
||||
m = tile_m * 128 + inner_m * 4 + sub_m;
|
||||
k_sf_out = tile_k * 4 + sub_k;
|
||||
} else if constexpr (flat_rank == 4) {
|
||||
int inner_m = cute::get<0>(flat);
|
||||
int sub_m = cute::get<1>(flat);
|
||||
int sub_k = cute::get<2>(flat);
|
||||
int tile_k = cute::get<3>(flat);
|
||||
|
||||
m = inner_m * 4 + sub_m;
|
||||
k_sf_out = tile_k * 4 + sub_k;
|
||||
} else {
|
||||
m = cute::get<0>(flat);
|
||||
k_sf_out = cute::get<1>(flat);
|
||||
}
|
||||
|
||||
if (m < MN && k_sf_out < K_sf) {
|
||||
dst[dst_idx] = src[m * K_sf + k_sf_out];
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// C API
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
extern "C" {
|
||||
|
||||
int cutlass_nvfp4_gemm_run(
|
||||
const void* A_ptr, const void* SFA_ptr,
|
||||
const void* B_ptr, const void* SFB_ptr,
|
||||
@@ -173,13 +183,11 @@ int cutlass_nvfp4_gemm_run(
|
||||
cudaMemsetAsync(sfb_cutlass.get(), 0, sfb_size * sizeof(ElementSF), stream);
|
||||
|
||||
int block = 256;
|
||||
// Grid size based on SOURCE elements (M*K_sf), not CUTLASS buffer size
|
||||
int sfa_src_total = M * K_sf;
|
||||
int sfb_src_total = N * K_sf;
|
||||
remap_sf_to_cutlass_kernel<<<(sfa_src_total + block - 1) / block, block, 0, stream>>>(
|
||||
static_cast<const ElementSF*>(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf, InputSFVectorSize);
|
||||
remap_sf_to_cutlass_kernel<<<(sfb_src_total + block - 1) / block, block, 0, stream>>>(
|
||||
static_cast<const ElementSF*>(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf, InputSFVectorSize);
|
||||
// Grid size based on DEST (CUTLASS buffer) size since kernel iterates dest
|
||||
remap_sf_to_cutlass_kernel<<<(sfa_size + block - 1) / block, block, 0, stream>>>(
|
||||
static_cast<const ElementSF*>(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf);
|
||||
remap_sf_to_cutlass_kernel<<<(sfb_size + block - 1) / block, block, 0, stream>>>(
|
||||
static_cast<const ElementSF*>(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf);
|
||||
|
||||
typename Gemm::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
|
||||
Reference in New Issue
Block a user