diff --git a/deep_gemm/include/deep_gemm/common/scheduler.cuh b/deep_gemm/include/deep_gemm/common/scheduler.cuh index 8ac8310..2324a9b 100644 --- a/deep_gemm/include/deep_gemm/common/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/common/scheduler.cuh @@ -15,19 +15,14 @@ template ::max(); - if constexpr (kGemmType == GemmType::MGroupedContiguous or - kGemmType == GemmType::MGroupedMasked) { - // For grouped GEMMs, let weights always stay in the L2 cache and read activations by once - num_best_blocks = kNumSMs; - } else { - for (const auto& candidate: {8u, 16u}) { - const auto& usage = kIsMulticastOnA ? - candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N - candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M - if (usage < min_usage) - min_usage = usage, num_best_blocks = candidate; - } + for (const auto& candidate: {8u, 16u}) { + const auto& usage = kIsMulticastOnA ? + candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N + candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M + if (usage < min_usage) + min_usage = usage, num_best_blocks = candidate; } + return num_best_blocks; }