Fix grouped GEMMs performance issue. (#168)
This commit is contained in:
@@ -15,19 +15,14 @@ template <GemmType kGemmType, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumS
|
||||
// Compile-time selection of the value returned as `num_best_blocks`.
// The candidate loop tries 8 and 16 and keeps the one with the smallest `usage`, where
// `usage` is `candidate * BLOCK_N + ceil_div(kNumSMs, candidate) * BLOCK_M` when
// `kIsMulticastOnA` is set (grouping on N) and the same expression with BLOCK_M/BLOCK_N
// swapped otherwise (grouping on M).
// NOTE(review): this text is a rendered commit diff without +/- markers, so two versions of
// the body appear back to back. The `if constexpr`/`else` form (grouped GEMM types get
// `kNumSMs`) looks like the pre-commit code and the second, unwrapped loop like the
// post-commit code — confirm against the repository before relying on either.
static constexpr uint32_t get_num_1d_blocks_per_group() {
|
||||
// Select the best from candidates
|
||||
uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits<uint32_t>::max();
|
||||
if constexpr (kGemmType == GemmType::MGroupedContiguous or
|
||||
kGemmType == GemmType::MGroupedMasked) {
|
||||
// For grouped GEMMs, keep the weights resident in the L2 cache and read the activations only once
|
||||
num_best_blocks = kNumSMs;
|
||||
} else {
|
||||
// Strict `<` below means the earlier candidate (8) wins when both usages are equal
|
||||
for (const auto& candidate: {8u, 16u}) {
|
||||
const auto& usage = kIsMulticastOnA ?
|
||||
candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N
|
||||
candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M
|
||||
if (usage < min_usage)
|
||||
min_usage = usage, num_best_blocks = candidate;
|
||||
}
|
||||
// Same candidate search as above, without the `if constexpr`/`else` wrapper
|
||||
for (const auto& candidate: {8u, 16u}) {
|
||||
const auto& usage = kIsMulticastOnA ?
|
||||
candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N
|
||||
candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M
|
||||
if (usage < min_usage)
|
||||
min_usage = usage, num_best_blocks = candidate;
|
||||
}
|
||||
|
||||
return num_best_blocks;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user