diff --git a/deep_gemm/include/deep_gemm/common/scheduler.cuh b/deep_gemm/include/deep_gemm/common/scheduler.cuh index 8ce8aa0..d114381 100644 --- a/deep_gemm/include/deep_gemm/common/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/common/scheduler.cuh @@ -169,7 +169,7 @@ struct Scheduler { // For SM90 only // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned - is_peer_cta_alive = kNum1DBlocksPerGroup % kNumMulticast == 0 or // Always aligned on N (constant bypass) + is_peer_cta_alive = num_n_blocks % kNumMulticast == 0 or // Always aligned on N (constant bypass) num_m_blocks % kNumMulticast == 0 or // Always aligned on M (constant bypass) (next_block_idx ^ 1) < num_blocks; // Peer CTA in bound get_swizzled_block_idx(next_block_idx, m_block_idx, n_block_idx);