diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp index 3ed4d2a..681e654 100644 --- a/csrc/jit_kernels/heuristics/common.hpp +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -152,8 +152,11 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k // Select M/N block sizes // TODO: support `% 16 == 8` block size on SM90 - const auto& block_ms = gemm_type == GemmType::MGroupedContiguous ? - std::vector{get_mk_alignment_for_contiguous_layout()} : std::vector{64, 128, 256}; + auto block_ms = std::vector{64, 128, 256}; + if (gemm_type == GemmType::MGroupedContiguous) + block_ms = std::vector{get_mk_alignment_for_contiguous_layout()}; + if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance + block_ms = std::vector{64, 128}; std::vector block_ns; for (int i = 16; i <= 256; i += 16) block_ns.push_back(i); @@ -214,7 +217,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k MulticastConfig best_multicast_config = {1, true}; const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality( gemm_type, m, n, best_block_m, best_block_n, num_sms); - const bool is_legal[2] = {is_legal_on_a, is_legal_on_b}; + const bool is_legal[2] = {is_legal_on_b, is_legal_on_a}; bool order[2] = {false, true}; if (best_block_m > best_block_n) std::swap(order[0], order[1]); diff --git a/csrc/jit_kernels/heuristics/sm100.hpp b/csrc/jit_kernels/heuristics/sm100.hpp index 4e58289..0679cad 100644 --- a/csrc/jit_kernels/heuristics/sm100.hpp +++ b/csrc/jit_kernels/heuristics/sm100.hpp @@ -91,8 +91,8 @@ struct SM100ArchSpec { const int& num_sms) { // TODO: support other layouts return { - is_multicast_legal(m, block_m, 2, num_sms, true) and (gemm_type == GemmType::Normal or gemm_type == GemmType::KGroupedContiguous), false, + is_multicast_legal(m, block_m, 2, num_sms, true) and (gemm_type == GemmType::Normal or gemm_type == GemmType::KGroupedContiguous), }; } diff --git a/csrc/jit_kernels/heuristics/sm90.hpp b/csrc/jit_kernels/heuristics/sm90.hpp index 16ca018..58faecf 100644 --- a/csrc/jit_kernels/heuristics/sm90.hpp +++ b/csrc/jit_kernels/heuristics/sm90.hpp @@ -71,7 +71,9 @@ struct SM90ArchSpec { const int& num_sms) { return { is_multicast_legal(n, block_n, 2, num_sms, gemm_type == GemmType::MGroupedMasked), - is_multicast_legal(m, block_m, 2, num_sms, false) and gemm_type != GemmType::MGroupedMasked, + // For masked GEMM layout, divisibility on N is also required as we must ensure the total number of blocks is even + is_multicast_legal(m, block_m, 2, num_sms, false) + and (gemm_type != GemmType::MGroupedMasked or is_multicast_legal(n, block_n, 2, num_sms, true)) }; }