cleanup: remove printf and diag function from CUDA kernel (build fix)

This commit is contained in:
2026-05-15 20:11:40 +00:00
parent e7c3341317
commit a09b9b53a3

View File

@@ -149,24 +149,8 @@ __global__ void remap_sf_to_cutlass_kernel(
}
if (m < MN && k_sf < K_sf) {
// SFA: source is (MN, K_sf) row-major → src[m * K_sf + k_sf]
// SFB: source is (K_sf, MN) row-major → src[k_sf * MN + m]
dst[dst_idx] = col_major_src ? src[k_sf * MN + m] : src[m * K_sf + k_sf];
}
// Diagnostic: print first 10 coordinate mappings (host-side, only dst_idx 0..9)
if (dst_idx < 10) {
printf("[SF-REMAP] dst=%d rank=%d", dst_idx, R);
if constexpr (R >= 1) printf(" f0=%d", (int)cute::get<0>(flat));
if constexpr (R >= 2) printf(" f1=%d", (int)cute::get<1>(flat));
if constexpr (R >= 3) printf(" f2=%d", (int)cute::get<2>(flat));
if constexpr (R >= 4) printf(" f3=%d", (int)cute::get<3>(flat));
if constexpr (R >= 5) printf(" f4=%d", (int)cute::get<4>(flat));
if constexpr (R >= 6) printf(" f5=%d", (int)cute::get<5>(flat));
if constexpr (R >= 7) printf(" f6=%d", (int)cute::get<6>(flat));
if constexpr (R >= 8) printf(" f7=%d", (int)cute::get<7>(flat));
printf(" -> m=%d k_sf=%d col_major=%d MN=%d K_sf=%d\n", m, k_sf, col_major_src, MN, K_sf);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -345,46 +329,4 @@ extern "C" int cutlass_nvfp4_gemm_run_prepacked_sfb(
} // extern "C"
// Diagnostic: verify SF coordinate mapping
extern "C" int cutlass_nvfp4_gemm_diag_sf_layout(
int M, int N, int K,
int* out_cosize, int* out_size, int* out_rank,
int* out_shape, int* out_stride, // arrays of 8 ints
int* out_coord_map, // flat (dst_idx -> (m, k_sf)) pairs, max 256 entries
int max_entries
) {
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
LayoutSFA layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
*out_cosize = cute::cosize(layout_SFA);
*out_size = cute::size(layout_SFA);
auto shape = layout_SFA.shape();
auto stride = layout_SFA.stride();
*out_rank = cute::rank(shape);
// Write shape and stride
for (int i = 0; i < 8; i++) {
out_shape[i] = (i < *out_rank) ? (int)cute::get<0>(cute::slice_and_offset<i>(shape, stride)) : 0;
out_stride[i] = 0;
}
// Write coordinate mapping
int total = std::min(*out_cosize, max_entries);
for (int i = 0; i < total; i++) {
auto coord = cute::idx2crd(i, layout_SFA.shape(), layout_SFA.stride());
auto flat = cute::flatten(coord);
constexpr int R = cute::rank_v<decltype(flat)>;
int m = 0, k_sf = 0;
if constexpr (R == 8) {
m = cute::get<0>(flat) + cute::get<1>(flat) * 32 + cute::get<2>(flat) * 128;
k_sf = cute::get<4>(flat) + cute::get<5>(flat) * 4;
}
out_coord_map[i * 2] = m;
out_coord_map[i * 2 + 1] = k_sf;
}
return total;
}
#endif