fix: SF remap — iterate dest indices, extract logical (m, k_sf) from nested coord

The forward-map approach (src -> layout_sf(m, k)) failed because CuTe's
layout operator requires coordinates matching the nested shape rank, and
passing flat (int, int) to a ((32,4),K) shape triggers Mismatched Ranks.

New approach: iterate over CUTLASS dest indices, use idx2crd to get the
hierarchical coordinate, flatten it, then extract logical (m, k_sf) by
interpreting the flattened sub-coordinates correctly:
  flat[0..2] = (inner_M, sub_M, tile_M) -> m = tile_M*128 + inner_M*4 + sub_M
  flat[3..5] = (inner_K, sub_K, tile_K) -> k_sf = tile_K*4 + sub_K
  (inner_K is within one SF group — same byte, so ignored for k_sf)

Previous bug: get<0> and get<1> of flatten gave (inner_M, sub_M) — both
M sub-indices. K information was never extracted, so only k_group=0 worked.

Dest buffer is zero-initialized so padding slots (where m >= MN or
k_sf >= K_sf) stay zero.
This commit is contained in:
2026-05-14 15:01:47 +00:00
parent 395cc31883
commit a8bd962452

View File

@@ -97,47 +97,57 @@ using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
// Scale factor remap kernel using CuTe layout operations
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
// Scale factor remap kernel using CuTe layout operations
/////////////////////////////////////////////////////////////////////////////////////////////////
template<typename LayoutSF>
__global__ void remap_sf_to_cutlass_kernel(
const cutlass::float_ue4m3_t* __restrict__ src, // (MN, K_sf) row-major
cutlass::float_ue4m3_t* __restrict__ dst, // CUTLASS interleaved layout
LayoutSF layout_sf, // CuTe layout for dst
int MN, int K_sf, // Source dimensions (in SF groups)
int SFVecSize // Elements per SF group (16 for NVFP4)
const cutlass::float_ue4m3_t* __restrict__ src,
cutlass::float_ue4m3_t* __restrict__ dst,
LayoutSF layout_sf,
int MN, int K_sf
) {
// Iterate over SOURCE indices (row-major) and write to CUTLASS destination.
// The layout maps logical (m, k_elements) -> CUTLASS linear index.
// This is the forward direction, which CuTe handles correctly.
//
// IMPORTANT: The CuTe layout uses ELEMENT-SPACE for K, not group-space.
// So k_group=3 with SFVecSize=16 maps to element k=3*16=48 in the layout.
//
// Previous approach (iterate over CUTLASS idx, reverse-map with idx2crd+flatten)
// was broken: flatten() on the nested CuTe coordinate gives atom sub-indices,
// not logical (m, k). This caused all K-groups > 0 in SFA to map to m*K_sf+0,
// losing K-group information entirely.
int src_idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = MN * K_sf;
if (src_idx >= total) return;
int m = src_idx / K_sf;
int k_sf = src_idx % K_sf;
int k_elem = k_sf * SFVecSize; // Convert to element-space for CuTe layout
// Use the CuTe layout to find the destination index for this (m, k_elem)
// The layout has nested shape like ((32, 4, tiles_m), (16, 4, tiles_k)) but
// we can pass a flat (m, k_elem) coordinate and CuTe will project correctly.
// Using layout_sf[m, k_elem] (square brackets) avoids the rank-matching issue
// that make_coord() has with nested shapes.
int dst_idx = layout_sf(m, k_elem);
int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = cute::size(layout_sf);
if (dst_idx >= total) return;
auto coord = cute::idx2crd(dst_idx, layout_sf.shape(), layout_sf.stride());
auto flat = cute::flatten(coord);
constexpr int flat_rank = cute::rank_v<decltype(flat)>;
int m, k_sf_out;
if constexpr (flat_rank == 6) {
// Full nested: (inner_M, sub_M, tile_M, inner_K, sub_K, tile_K)
int inner_m = cute::get<0>(flat);
int sub_m = cute::get<1>(flat);
int tile_m = cute::get<2>(flat);
// get<3> = inner_k (within one SF group — same byte)
int sub_k = cute::get<4>(flat);
int tile_k = cute::get<5>(flat);
m = tile_m * 128 + inner_m * 4 + sub_m;
k_sf_out = tile_k * 4 + sub_k;
} else if constexpr (flat_rank == 4) {
int inner_m = cute::get<0>(flat);
int sub_m = cute::get<1>(flat);
int sub_k = cute::get<2>(flat);
int tile_k = cute::get<3>(flat);
m = inner_m * 4 + sub_m;
k_sf_out = tile_k * 4 + sub_k;
} else {
m = cute::get<0>(flat);
k_sf_out = cute::get<1>(flat);
}
if (m < MN && k_sf_out < K_sf) {
dst[dst_idx] = src[m * K_sf + k_sf_out];
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// C API
/////////////////////////////////////////////////////////////////////////////////////////////////
extern "C" {
int cutlass_nvfp4_gemm_run(
const void* A_ptr, const void* SFA_ptr,
const void* B_ptr, const void* SFB_ptr,
@@ -173,13 +183,11 @@ int cutlass_nvfp4_gemm_run(
cudaMemsetAsync(sfb_cutlass.get(), 0, sfb_size * sizeof(ElementSF), stream);
int block = 256;
// Grid size based on SOURCE elements (M*K_sf), not CUTLASS buffer size
int sfa_src_total = M * K_sf;
int sfb_src_total = N * K_sf;
remap_sf_to_cutlass_kernel<<<(sfa_src_total + block - 1) / block, block, 0, stream>>>(
static_cast<const ElementSF*>(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf, InputSFVectorSize);
remap_sf_to_cutlass_kernel<<<(sfb_src_total + block - 1) / block, block, 0, stream>>>(
static_cast<const ElementSF*>(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf, InputSFVectorSize);
// Grid size based on DEST (CUTLASS buffer) size since kernel iterates dest
remap_sf_to_cutlass_kernel<<<(sfa_size + block - 1) / block, block, 0, stream>>>(
static_cast<const ElementSF*>(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf);
remap_sf_to_cutlass_kernel<<<(sfb_size + block - 1) / block, block, 0, stream>>>(
static_cast<const ElementSF*>(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf);
typename Gemm::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,