Fix grouped gemms performance issue. (#168)

This commit is contained in:
zhonghui-J
2025-08-22 17:35:43 +08:00
committed by GitHub
parent e38c2e3103
commit 2da871e304

View File

@@ -15,19 +15,14 @@ template <GemmType kGemmType, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumS
static constexpr uint32_t get_num_1d_blocks_per_group() {
// Select the best from candidates
uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits<uint32_t>::max();
if constexpr (kGemmType == GemmType::MGroupedContiguous or
kGemmType == GemmType::MGroupedMasked) {
// For grouped GEMMs, let weights always stay in the L2 cache and read activations by once
num_best_blocks = kNumSMs;
} else {
for (const auto& candidate: {8u, 16u}) {
const auto& usage = kIsMulticastOnA ?
candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N
candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M
if (usage < min_usage)
min_usage = usage, num_best_blocks = candidate;
}
for (const auto& candidate: {8u, 16u}) {
const auto& usage = kIsMulticastOnA ?
candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N
candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M
if (usage < min_usage)
min_usage = usage, num_best_blocks = candidate;
}
return num_best_blocks;
}