cleanup: remove printf and diag function from CUDA kernel (build fix)
This commit is contained in:
@@ -149,24 +149,8 @@ __global__ void remap_sf_to_cutlass_kernel(
|
||||
}
|
||||
|
||||
if (m < MN && k_sf < K_sf) {
|
||||
// SFA: source is (MN, K_sf) row-major → src[m * K_sf + k_sf]
|
||||
// SFB: source is (K_sf, MN) row-major → src[k_sf * MN + m]
|
||||
dst[dst_idx] = col_major_src ? src[k_sf * MN + m] : src[m * K_sf + k_sf];
|
||||
}
|
||||
|
||||
// Diagnostic: print first 10 coordinate mappings (host-side, only dst_idx 0..9)
|
||||
if (dst_idx < 10) {
|
||||
printf("[SF-REMAP] dst=%d rank=%d", dst_idx, R);
|
||||
if constexpr (R >= 1) printf(" f0=%d", (int)cute::get<0>(flat));
|
||||
if constexpr (R >= 2) printf(" f1=%d", (int)cute::get<1>(flat));
|
||||
if constexpr (R >= 3) printf(" f2=%d", (int)cute::get<2>(flat));
|
||||
if constexpr (R >= 4) printf(" f3=%d", (int)cute::get<3>(flat));
|
||||
if constexpr (R >= 5) printf(" f4=%d", (int)cute::get<4>(flat));
|
||||
if constexpr (R >= 6) printf(" f5=%d", (int)cute::get<5>(flat));
|
||||
if constexpr (R >= 7) printf(" f6=%d", (int)cute::get<6>(flat));
|
||||
if constexpr (R >= 8) printf(" f7=%d", (int)cute::get<7>(flat));
|
||||
printf(" -> m=%d k_sf=%d col_major=%d MN=%d K_sf=%d\n", m, k_sf, col_major_src, MN, K_sf);
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -345,46 +329,4 @@ extern "C" int cutlass_nvfp4_gemm_run_prepacked_sfb(
|
||||
|
||||
} // extern "C"
|
||||
|
||||
// Diagnostic: verify SF coordinate mapping
|
||||
extern "C" int cutlass_nvfp4_gemm_diag_sf_layout(
|
||||
int M, int N, int K,
|
||||
int* out_cosize, int* out_size, int* out_rank,
|
||||
int* out_shape, int* out_stride, // arrays of 8 ints
|
||||
int* out_coord_map, // flat (dst_idx -> (m, k_sf)) pairs, max 256 entries
|
||||
int max_entries
|
||||
) {
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
LayoutSFA layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
|
||||
|
||||
*out_cosize = cute::cosize(layout_SFA);
|
||||
*out_size = cute::size(layout_SFA);
|
||||
|
||||
auto shape = layout_SFA.shape();
|
||||
auto stride = layout_SFA.stride();
|
||||
*out_rank = cute::rank(shape);
|
||||
|
||||
// Write shape and stride
|
||||
for (int i = 0; i < 8; i++) {
|
||||
out_shape[i] = (i < *out_rank) ? (int)cute::get<0>(cute::slice_and_offset<i>(shape, stride)) : 0;
|
||||
out_stride[i] = 0;
|
||||
}
|
||||
|
||||
// Write coordinate mapping
|
||||
int total = std::min(*out_cosize, max_entries);
|
||||
for (int i = 0; i < total; i++) {
|
||||
auto coord = cute::idx2crd(i, layout_SFA.shape(), layout_SFA.stride());
|
||||
auto flat = cute::flatten(coord);
|
||||
constexpr int R = cute::rank_v<decltype(flat)>;
|
||||
int m = 0, k_sf = 0;
|
||||
if constexpr (R == 8) {
|
||||
m = cute::get<0>(flat) + cute::get<1>(flat) * 32 + cute::get<2>(flat) * 128;
|
||||
k_sf = cute::get<4>(flat) + cute::get<5>(flat) * 4;
|
||||
}
|
||||
out_coord_map[i * 2] = m;
|
||||
out_coord_map[i * 2 + 1] = k_sf;
|
||||
}
|
||||
|
||||
return total;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user