debug: idx2crd+flatten approach with printf to determine flat_rank
Go back to the idx2crd approach, which compiles and runs. Add printf output for flat_rank, MN, K_sf, total, and the first extracted coordinate. Handle flattened ranks 2-6 when extracting the logical (m, k_sf) pair. The output will tell us the actual flat_rank and whether our extraction is correct.
This commit is contained in:
@@ -94,16 +94,20 @@ using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Scale factor remap kernel: row-major source -> CUTLASS interleaved layout
|
||||
// Scale factor remap: row-major source -> CUTLASS interleaved layout
|
||||
//
|
||||
// Iterates over source (m, k_group) indices and uses layout_sf(m, k_elem) to
|
||||
// compute the CUTLASS destination offset. CuTe's layout operator() decomposes
|
||||
// flat coordinates into nested sub-coordinates automatically — no idx2crd,
|
||||
// no flatten, no rank inspection.
|
||||
// Iterates over CUTLASS dest indices, uses idx2crd to get the hierarchical coordinate,
|
||||
// then extracts logical (m, k_sf) from the flattened result.
|
||||
//
|
||||
// K is in element-space for the layout: k_group * SFVecSize.
|
||||
// Iterating groups (not every element) is efficient since all 16 elements
|
||||
// within a group share one SF byte.
|
||||
// The key challenge: idx2crd returns a nested (hierarchical) coordinate, so it must be flattened first.
|
||||
// For SFA with Step<_2,_1> tiling, the layout shape is:
|
||||
// ((32, 4, n_m_tiles), (16, 4, n_k_tiles))
|
||||
// Flattening gives: (inner_m, sub_m, tile_m, inner_k, sub_k, tile_k)
|
||||
// where inner_m in [0,32), sub_m in [0,4), tile_m in [0, n_m_tiles)
|
||||
// inner_k in [0,16) (within one SF group), sub_k in [0,4), tile_k in [0, n_k_tiles)
|
||||
//
|
||||
// Logical m = tile_m * 128 + inner_m * 4 + sub_m
|
||||
// Logical k_sf = tile_k * 4 + sub_k (inner_k is within one SF group — same byte)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename LayoutSF>
|
||||
@@ -111,26 +115,75 @@ __global__ void remap_sf_to_cutlass_kernel(
|
||||
const cutlass::float_ue4m3_t* __restrict__ src, // (MN, K_sf) row-major
|
||||
cutlass::float_ue4m3_t* __restrict__ dst, // CUTLASS interleaved layout (zero-initialized)
|
||||
LayoutSF layout_sf, // CuTe layout for dst
|
||||
int MN, int K_sf, // Source dimensions (in SF groups)
|
||||
int SFVecSize // Elements per SF group (16 for NVFP4)
|
||||
int MN, int K_sf // Source dimensions (in SF groups)
|
||||
) {
|
||||
int src_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int total = MN * K_sf;
|
||||
if (src_idx >= total) return;
|
||||
int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int total = cute::size(layout_sf);
|
||||
if (dst_idx >= total) return;
|
||||
|
||||
int m = src_idx / K_sf;
|
||||
int k_group = src_idx % K_sf;
|
||||
int k_elem = k_group * SFVecSize;
|
||||
auto coord = cute::idx2crd(dst_idx, layout_sf.shape(), layout_sf.stride());
|
||||
auto flat = cute::flatten(coord);
|
||||
|
||||
// Construct a coordinate matching the layout's top-level mode structure.
|
||||
// The layout shape from tile_to_shape(SfAtom{}, make_shape(MN, K), Step<_2, _1>)
|
||||
// has two top-level modes:
|
||||
// Mode 0 (M): shape (32, 4) — the SfAtom's M sub-structure
|
||||
// Mode 1 (K): int — the K dimension
|
||||
// For mode 0, we decompose m into (inner_m, sub_m) = (m / 4, m % 4)
|
||||
// For mode 1, we pass k_elem directly.
|
||||
auto coord = cute::make_tuple(cute::make_tuple(m / 4, m % 4), k_elem);
|
||||
dst[layout_sf(coord)] = src[src_idx];
|
||||
constexpr int R = cute::rank_v<decltype(flat)>;
|
||||
|
||||
// Debug: print rank once (only thread 0)
|
||||
if (dst_idx == 0) {
|
||||
printf("[remap] flat_rank=%d MN=%d K_sf=%d total=%d\n", R, MN, K_sf, total);
|
||||
}
|
||||
|
||||
int m = 0, k_sf = 0;
|
||||
|
||||
if constexpr (R == 6) {
|
||||
// Full 6: (inner_m, sub_m, tile_m, inner_k, sub_k, tile_k)
|
||||
int inner_m = cute::get<0>(flat);
|
||||
int sub_m = cute::get<1>(flat);
|
||||
int tile_m = cute::get<2>(flat);
|
||||
// get<3> = inner_k (within SF group, ignored for k_sf)
|
||||
int sub_k = cute::get<4>(flat);
|
||||
int tile_k = cute::get<5>(flat);
|
||||
m = tile_m * 128 + inner_m * 4 + sub_m;
|
||||
k_sf = tile_k * 4 + sub_k;
|
||||
} else if constexpr (R == 5) {
|
||||
// 5: maybe (inner_m, sub_m, tile_m, sub_k, tile_k) or similar
|
||||
int inner_m = cute::get<0>(flat);
|
||||
int sub_m = cute::get<1>(flat);
|
||||
int tile_m = cute::get<2>(flat);
|
||||
int sub_k = cute::get<3>(flat);
|
||||
int tile_k = cute::get<4>(flat);
|
||||
m = tile_m * 128 + inner_m * 4 + sub_m;
|
||||
k_sf = tile_k * 4 + sub_k;
|
||||
} else if constexpr (R == 4) {
|
||||
// 4: (inner_m, sub_m, sub_k, tile_k) — small M fits in 1 tile
|
||||
int inner_m = cute::get<0>(flat);
|
||||
int sub_m = cute::get<1>(flat);
|
||||
int sub_k = cute::get<2>(flat);
|
||||
int tile_k = cute::get<3>(flat);
|
||||
m = inner_m * 4 + sub_m;
|
||||
k_sf = tile_k * 4 + sub_k;
|
||||
} else if constexpr (R == 3) {
|
||||
// 3: maybe (inner_m, sub_m, k_combined)
|
||||
int inner_m = cute::get<0>(flat);
|
||||
int sub_m = cute::get<1>(flat);
|
||||
int k_comb = cute::get<2>(flat);
|
||||
m = inner_m * 4 + sub_m;
|
||||
k_sf = k_comb;
|
||||
} else if constexpr (R == 2) {
|
||||
// 2: flat (m, k) — no nesting
|
||||
m = cute::get<0>(flat);
|
||||
k_sf = cute::get<1>(flat);
|
||||
} else {
|
||||
// Fallback: index 0 and 1
|
||||
m = 0; k_sf = 0;
|
||||
if (dst_idx == 0) printf("[remap] UNEXPECTED flat_rank=%d\n", R);
|
||||
}
|
||||
|
||||
if (dst_idx == 0) {
|
||||
printf("[remap] first coord: m=%d k_sf=%d (src_idx would be %d)\n", m, k_sf, m * K_sf + k_sf);
|
||||
}
|
||||
|
||||
if (m < MN && k_sf < K_sf) {
|
||||
dst[dst_idx] = src[m * K_sf + k_sf];
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -164,19 +217,16 @@ int cutlass_nvfp4_gemm_run(
|
||||
int sfb_size = cute::size(layout_SFB);
|
||||
int K_sf = K / InputSFVectorSize;
|
||||
|
||||
// Allocate CUTLASS-layout buffers, zero-init, and remap scales
|
||||
cutlass::device_memory::allocation<ElementSF> sfa_cutlass(sfa_size);
|
||||
cutlass::device_memory::allocation<ElementSF> sfb_cutlass(sfb_size);
|
||||
cudaMemsetAsync(sfa_cutlass.get(), 0, sfa_size * sizeof(ElementSF), stream);
|
||||
cudaMemsetAsync(sfb_cutlass.get(), 0, sfb_size * sizeof(ElementSF), stream);
|
||||
|
||||
int block = 256;
|
||||
int sfa_src = M * K_sf;
|
||||
int sfb_src = N * K_sf;
|
||||
remap_sf_to_cutlass_kernel<<<(sfa_src + block - 1) / block, block, 0, stream>>>(
|
||||
static_cast<const ElementSF*>(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf, InputSFVectorSize);
|
||||
remap_sf_to_cutlass_kernel<<<(sfb_src + block - 1) / block, block, 0, stream>>>(
|
||||
static_cast<const ElementSF*>(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf, InputSFVectorSize);
|
||||
remap_sf_to_cutlass_kernel<<<(sfa_size + block - 1) / block, block, 0, stream>>>(
|
||||
static_cast<const ElementSF*>(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf);
|
||||
remap_sf_to_cutlass_kernel<<<(sfb_size + block - 1) / block, block, 0, stream>>>(
|
||||
static_cast<const ElementSF*>(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf);
|
||||
|
||||
typename Gemm::Arguments arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
|
||||
Reference in New Issue
Block a user