From 2da871e304b1dad11ded860cf71252da9252c2b5 Mon Sep 17 00:00:00 2001 From: zhonghui-J Date: Fri, 22 Aug 2025 17:35:43 +0800 Subject: [PATCH] Fix grouped gemms performance issue. (#168) --- .../include/deep_gemm/common/scheduler.cuh | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/deep_gemm/include/deep_gemm/common/scheduler.cuh b/deep_gemm/include/deep_gemm/common/scheduler.cuh index 8ac8310..2324a9b 100644 --- a/deep_gemm/include/deep_gemm/common/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/common/scheduler.cuh @@ -15,19 +15,14 @@ template ::max(); - if constexpr (kGemmType == GemmType::MGroupedContiguous or - kGemmType == GemmType::MGroupedMasked) { - // For grouped GEMMs, let weights always stay in the L2 cache and read activations by once - num_best_blocks = kNumSMs; - } else { - for (const auto& candidate: {8u, 16u}) { - const auto& usage = kIsMulticastOnA ? - candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N - candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M - if (usage < min_usage) - min_usage = usage, num_best_blocks = candidate; - } + for (const auto& candidate: {8u, 16u}) { + const auto& usage = kIsMulticastOnA ? + candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N + candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M + if (usage < min_usage) + min_usage = usage, num_best_blocks = candidate; } + return num_best_blocks; }