Fix multicast bug and optimize masked GEMM (#193)

* Fix multicast bug and profile masked GEMM

* Updates and lint

---------

Co-authored-by: Kuai Yu <yukuai@deepseek.com>
Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
This commit is contained in:
yukuai26
2025-09-12 17:12:27 +08:00
committed by GitHub
parent ea9c5d9270
commit 79f48ee15a
3 changed files with 10 additions and 5 deletions

View File

@@ -152,8 +152,11 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
// Select M/N block sizes
// TODO: support `% 16 == 8` block size on SM90
-const auto& block_ms = gemm_type == GemmType::MGroupedContiguous ?
-std::vector{get_mk_alignment_for_contiguous_layout()} : std::vector{64, 128, 256};
+auto block_ms = std::vector{64, 128, 256};
+if (gemm_type == GemmType::MGroupedContiguous)
+block_ms = std::vector{get_mk_alignment_for_contiguous_layout()};
+if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance
+block_ms = std::vector{64, 128};
std::vector<int> block_ns;
for (int i = 16; i <= 256; i += 16)
block_ns.push_back(i);
@@ -214,7 +217,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
MulticastConfig best_multicast_config = {1, true};
const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality(
gemm_type, m, n, best_block_m, best_block_n, num_sms);
-const bool is_legal[2] = {is_legal_on_a, is_legal_on_b};
+const bool is_legal[2] = {is_legal_on_b, is_legal_on_a};
bool order[2] = {false, true};
if (best_block_m > best_block_n)
std::swap(order[0], order[1]);

View File

@@ -91,8 +91,8 @@ struct SM100ArchSpec {
const int& num_sms) {
// TODO: support other layouts
return {
-is_multicast_legal(m, block_m, 2, num_sms, true) and (gemm_type == GemmType::Normal or gemm_type == GemmType::KGroupedContiguous),
false,
+is_multicast_legal(m, block_m, 2, num_sms, true) and (gemm_type == GemmType::Normal or gemm_type == GemmType::KGroupedContiguous),
};
}

View File

@@ -71,7 +71,9 @@ struct SM90ArchSpec {
const int& num_sms) {
return {
is_multicast_legal(n, block_n, 2, num_sms, gemm_type == GemmType::MGroupedMasked),
-is_multicast_legal(m, block_m, 2, num_sms, false) and gemm_type != GemmType::MGroupedMasked,
+// For masked GEMM layout, divisibility on N is also required as we must ensure the total number of blocks is even
+is_multicast_legal(m, block_m, 2, num_sms, false)
+and (gemm_type != GemmType::MGroupedMasked or is_multicast_legal(n, block_n, 2, num_sms, true))
};
}