debug: idx2crd+flatten approach with printf to determine flat_rank

Going back to the idx2crd approach which compiles and runs.
Added printf for flat_rank, MN, K_sf, and first coordinate extraction.
Handles ranks 2-6 with logical (m, k_sf) extraction.
This will tell us the actual flat_rank and whether our extraction is correct.
This commit is contained in:
2026-05-14 15:34:46 +00:00
parent 2ac3a7d631
commit d2c1c76f5b

View File

@@ -94,16 +94,20 @@ using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
/////////////////////////////////////////////////////////////////////////////////////////////////
// Scale factor remap kernel: row-major source -> CUTLASS interleaved layout
// Scale factor remap: row-major source -> CUTLASS interleaved layout
//
// Iterates over source (m, k_group) indices and uses layout_sf(m, k_elem) to
// compute the CUTLASS destination offset. CuTe's layout operator() decomposes
// flat coordinates into nested sub-coordinates automatically — no idx2crd,
// no flatten, no rank inspection.
// Iterates over CUTLASS dest indices, uses idx2crd to get the hierarchical coordinate,
// then extracts logical (m, k_sf) from the flattened result.
//
// K is in element-space for the layout: k_group * SFVecSize.
// Iterating groups (not every element) is efficient since all 16 elements
// within a group share one SF byte.
// The key challenge: the flattened coordinate from idx2crd has nested structure.
// For SFA with Step<_2,_1> tiling, the layout shape is:
// ((32, 4, n_m_tiles), (16, 4, n_k_tiles))
// Flattening gives: (inner_m, sub_m, tile_m, inner_k, sub_k, tile_k)
// where inner_m in [0,32), sub_m in [0,4), tile_m in [0, n_m_tiles)
// inner_k in [0,16) (within one SF group), sub_k in [0,4), tile_k in [0, n_k_tiles)
//
// Logical m = tile_m * 128 + inner_m * 4 + sub_m
// Logical k_sf = tile_k * 4 + sub_k (inner_k is within one SF group — same byte)
/////////////////////////////////////////////////////////////////////////////////////////////////
template<typename LayoutSF>
@@ -111,26 +115,75 @@ __global__ void remap_sf_to_cutlass_kernel(
const cutlass::float_ue4m3_t* __restrict__ src, // (MN, K_sf) row-major
cutlass::float_ue4m3_t* __restrict__ dst, // CUTLASS interleaved layout (zero-initialized)
LayoutSF layout_sf, // CuTe layout for dst
int MN, int K_sf, // Source dimensions (in SF groups)
int SFVecSize // Elements per SF group (16 for NVFP4)
int MN, int K_sf // Source dimensions (in SF groups)
) {
int src_idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = MN * K_sf;
if (src_idx >= total) return;
int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = cute::size(layout_sf);
if (dst_idx >= total) return;
int m = src_idx / K_sf;
int k_group = src_idx % K_sf;
int k_elem = k_group * SFVecSize;
auto coord = cute::idx2crd(dst_idx, layout_sf.shape(), layout_sf.stride());
auto flat = cute::flatten(coord);
// Construct a coordinate matching the layout's top-level mode structure.
// The layout shape from tile_to_shape(SfAtom{}, make_shape(MN, K), Step<_2, _1>)
// has two top-level modes:
// Mode 0 (M): shape (32, 4) — the SfAtom's M sub-structure
// Mode 1 (K): int — the K dimension
// For mode 0, we decompose m into (inner_m, sub_m) = (m / 4, m % 4)
// For mode 1, we pass k_elem directly.
auto coord = cute::make_tuple(cute::make_tuple(m / 4, m % 4), k_elem);
dst[layout_sf(coord)] = src[src_idx];
constexpr int R = cute::rank_v<decltype(flat)>;
// Debug: print rank once (only thread 0)
if (dst_idx == 0) {
printf("[remap] flat_rank=%d MN=%d K_sf=%d total=%d\n", R, MN, K_sf, total);
}
int m = 0, k_sf = 0;
if constexpr (R == 6) {
// Full 6: (inner_m, sub_m, tile_m, inner_k, sub_k, tile_k)
int inner_m = cute::get<0>(flat);
int sub_m = cute::get<1>(flat);
int tile_m = cute::get<2>(flat);
// get<3> = inner_k (within SF group, ignored for k_sf)
int sub_k = cute::get<4>(flat);
int tile_k = cute::get<5>(flat);
m = tile_m * 128 + inner_m * 4 + sub_m;
k_sf = tile_k * 4 + sub_k;
} else if constexpr (R == 5) {
// 5: maybe (inner_m, sub_m, tile_m, sub_k, tile_k) or similar
int inner_m = cute::get<0>(flat);
int sub_m = cute::get<1>(flat);
int tile_m = cute::get<2>(flat);
int sub_k = cute::get<3>(flat);
int tile_k = cute::get<4>(flat);
m = tile_m * 128 + inner_m * 4 + sub_m;
k_sf = tile_k * 4 + sub_k;
} else if constexpr (R == 4) {
// 4: (inner_m, sub_m, sub_k, tile_k) — small M fits in 1 tile
int inner_m = cute::get<0>(flat);
int sub_m = cute::get<1>(flat);
int sub_k = cute::get<2>(flat);
int tile_k = cute::get<3>(flat);
m = inner_m * 4 + sub_m;
k_sf = tile_k * 4 + sub_k;
} else if constexpr (R == 3) {
// 3: maybe (inner_m, sub_m, k_combined)
int inner_m = cute::get<0>(flat);
int sub_m = cute::get<1>(flat);
int k_comb = cute::get<2>(flat);
m = inner_m * 4 + sub_m;
k_sf = k_comb;
} else if constexpr (R == 2) {
// 2: flat (m, k) — no nesting
m = cute::get<0>(flat);
k_sf = cute::get<1>(flat);
} else {
// Fallback: index 0 and 1
m = 0; k_sf = 0;
if (dst_idx == 0) printf("[remap] UNEXPECTED flat_rank=%d\n", R);
}
if (dst_idx == 0) {
printf("[remap] first coord: m=%d k_sf=%d (src_idx would be %d)\n", m, k_sf, m * K_sf + k_sf);
}
if (m < MN && k_sf < K_sf) {
dst[dst_idx] = src[m * K_sf + k_sf];
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -164,19 +217,16 @@ int cutlass_nvfp4_gemm_run(
int sfb_size = cute::size(layout_SFB);
int K_sf = K / InputSFVectorSize;
// Allocate CUTLASS-layout buffers, zero-init, and remap scales
cutlass::device_memory::allocation<ElementSF> sfa_cutlass(sfa_size);
cutlass::device_memory::allocation<ElementSF> sfb_cutlass(sfb_size);
cudaMemsetAsync(sfa_cutlass.get(), 0, sfa_size * sizeof(ElementSF), stream);
cudaMemsetAsync(sfb_cutlass.get(), 0, sfb_size * sizeof(ElementSF), stream);
int block = 256;
int sfa_src = M * K_sf;
int sfb_src = N * K_sf;
remap_sf_to_cutlass_kernel<<<(sfa_src + block - 1) / block, block, 0, stream>>>(
static_cast<const ElementSF*>(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf, InputSFVectorSize);
remap_sf_to_cutlass_kernel<<<(sfb_src + block - 1) / block, block, 0, stream>>>(
static_cast<const ElementSF*>(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf, InputSFVectorSize);
remap_sf_to_cutlass_kernel<<<(sfa_size + block - 1) / block, block, 0, stream>>>(
static_cast<const ElementSF*>(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf);
remap_sf_to_cutlass_kernel<<<(sfb_size + block - 1) / block, block, 0, stream>>>(
static_cast<const ElementSF*>(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf);
typename Gemm::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,