Fix multicast bug and optimize masked GEMM (#193)
* Fix multicast bug and profile masked GEMM * Updates and lint --------- Co-authored-by: Kuai Yu <yukuai@deepseek.com> Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
This commit is contained in:
@@ -152,8 +152,11 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
|
||||
|
||||
// Select M/N block sizes
|
||||
// TODO: support `% 16 == 8` block size on SM90
|
||||
const auto& block_ms = gemm_type == GemmType::MGroupedContiguous ?
|
||||
std::vector{get_mk_alignment_for_contiguous_layout()} : std::vector{64, 128, 256};
|
||||
auto block_ms = std::vector{64, 128, 256};
|
||||
if (gemm_type == GemmType::MGroupedContiguous)
|
||||
block_ms = std::vector{get_mk_alignment_for_contiguous_layout()};
|
||||
if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance
|
||||
block_ms = std::vector{64, 128};
|
||||
std::vector<int> block_ns;
|
||||
for (int i = 16; i <= 256; i += 16)
|
||||
block_ns.push_back(i);
|
||||
@@ -214,7 +217,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
|
||||
MulticastConfig best_multicast_config = {1, true};
|
||||
const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality(
|
||||
gemm_type, m, n, best_block_m, best_block_n, num_sms);
|
||||
const bool is_legal[2] = {is_legal_on_a, is_legal_on_b};
|
||||
const bool is_legal[2] = {is_legal_on_b, is_legal_on_a};
|
||||
bool order[2] = {false, true};
|
||||
if (best_block_m > best_block_n)
|
||||
std::swap(order[0], order[1]);
|
||||
|
||||
@@ -91,8 +91,8 @@ struct SM100ArchSpec {
|
||||
const int& num_sms) {
|
||||
// TODO: support other layouts
|
||||
return {
|
||||
is_multicast_legal(m, block_m, 2, num_sms, true) and (gemm_type == GemmType::Normal or gemm_type == GemmType::KGroupedContiguous),
|
||||
false,
|
||||
is_multicast_legal(m, block_m, 2, num_sms, true) and (gemm_type == GemmType::Normal or gemm_type == GemmType::KGroupedContiguous),
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -71,7 +71,9 @@ struct SM90ArchSpec {
|
||||
const int& num_sms) {
|
||||
return {
|
||||
is_multicast_legal(n, block_n, 2, num_sms, gemm_type == GemmType::MGroupedMasked),
|
||||
is_multicast_legal(m, block_m, 2, num_sms, false) and gemm_type != GemmType::MGroupedMasked,
|
||||
// For masked GEMM layout, divisibility on N is also required as we must ensure the total number of blocks is even
|
||||
is_multicast_legal(m, block_m, 2, num_sms, false)
|
||||
and (gemm_type != GemmType::MGroupedMasked or is_multicast_legal(n, block_n, 2, num_sms, true))
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user