diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp index 7b8318d..a7371b7 100644 --- a/csrc/jit_kernels/heuristics/common.hpp +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -194,7 +194,9 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k // Case 2: same `block_n`, smaller `block_m` (wasted) success |= block_n == best_block_n and block_m < best_block_m; // Case 3: different for both `block_m` and `block_n`, larger `block_n` is better - success |= block_m != best_block_m and block_n > best_block_n; + // NOTES: don't pick `block_m/block_n` larger than shape `m/n` in this case + success |= block_m != best_block_m and block_n > best_block_n + and block_n <= n and block_m <= m; } }