diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index ee6e4a4..85e13a9 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -27,7 +27,7 @@ __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) { return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads; } -template = 900)) or defined(__CLION_IDE__) + // Currently, only BLOCK_N size of 160 is classified as a largeBlockTile configuration in our optimization framework. + constexpr bool largeBlockTile = (BLOCK_N == 160); // Scaling checks DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); - DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block"); + if constexpr(!largeBlockTile) { + DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block"); + } // Types using WGMMA = typename FP8MMASelector::type; @@ -146,205 +150,760 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, grouped_layout); - if (threadIdx.x >= kNumMathThreads) { - // TMA warp-group for loading data - cutlass::arch::warpgroup_reg_dealloc(); + if constexpr (largeBlockTile) { + auto scheduler = SchedulerLargeBlockTile(SHAPE_N, grouped_layout); + if (threadIdx.x >= kNumMathThreads) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (threadIdx.x == kNumMathThreads) { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + + // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all + // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait consumer release + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + + auto& full_barrier = *full_barriers[s]; + int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; + + // Issue TMA A without broadcasting + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + smem_a[s], k_idx, scheduler.get_global_idx(SHAPE_M, BLOCK_M, m_block_idx)); + tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), + smem_scales_a[s], m_block_idx * BLOCK_M, + scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K)); + + // Issue TMA B with broadcasting + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); + + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + full_barriers[s]->arrive(); + } + }); + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + #pragma unroll + for (uint32_t s = 0; s < kNumStages; ++ s) + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); + const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; - // NOTES: only one thread (or warp) will be used - if (threadIdx.x == kNumMathThreads) { // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Decide the number of scales B to load + DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); + uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; + if constexpr (not kMustUseUniformedScaleB) { + num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; + num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8; + } + uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2); + // Accumulation for WGMMA or CUDA promotion + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; + // Empty barrier arrival + auto empty_barrier_arrive = [&](int s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void(); + } + }; + // To ensure optimal performance, conditional checks must never be placed inside the loop body of kNumInnerStages. + switch (n_block_idx % 4) { + case 0: + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + if (threadIdx.x >= 32) { + auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); + auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; + #pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) + st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); + } + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Launch MMAs + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + #pragma unroll + for (int s = 0; s < kNumInnerStages; ++ s) { + // Read B scales + float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; + // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); + + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); + + // Commit WGMMA instructions + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(s); + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + // In performance-critical code paths, we rigorously avoid branching logic and unnecessary instructions. + // Here, we adopt a brute-force approach by explicitly unrolling computations for all possible scenarios. + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4 / 5 * 4; ++ i) { + final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3]; + } + #pragma unroll + for (int i = WGMMA::kNumAccum / 4 / 5 * 4; i < WGMMA::kNumAccum / 4; ++ i) { + final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3]; + } + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }); + + // Write back to shared memory using STSM + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { + SM90_U32x4_STSM_N::copy( + __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), + __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), + __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), + __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) + ); + } + if constexpr (WGMMA::kNumAccum % 8 != 0) { + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 + ); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Use TMA store to write back to global memory + if (threadIdx.x == 0) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, + scheduler.get_global_idx(SHAPE_M, BLOCK_M, m_block_idx)); + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); + } + __syncwarp(); + break; + case 1: + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + if (threadIdx.x >= 32) { + auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); + auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; + #pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) + st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); + } + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + + // Launch MMAs + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + #pragma unroll + for (int s = 0; s < kNumInnerStages; ++ s) { + // Read B scales + float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; + // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); + + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); + + // Commit WGMMA instructions + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(s); + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4 / 5 * 3; ++ i) { + final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3]; + } + #pragma unroll + for (int i = WGMMA::kNumAccum / 4 / 5 * 3; i < WGMMA::kNumAccum / 4; ++ i) { + final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3]; + } + } + + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }); + + // Write back to shared memory using STSM + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { + SM90_U32x4_STSM_N::copy( + __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), + __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), + __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), + __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) + ); + } + if constexpr (WGMMA::kNumAccum % 8 != 0) { + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 + ); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Use TMA store to write back to global memory + if (threadIdx.x == 0) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, + scheduler.get_global_idx(SHAPE_M, BLOCK_M, m_block_idx)); + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); + } + __syncwarp(); + break; + case 2: + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + if (threadIdx.x >= 32) { + auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); + auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; + #pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) + st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); + } + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Launch MMAs + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + #pragma unroll + for (int s = 0; s < kNumInnerStages; ++ s) { + // Read B scales + float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; + // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); + + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); + + // Commit WGMMA instructions + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(s); + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4 / 5 * 2; ++ i) { + final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3]; + } + #pragma unroll + for (int i = WGMMA::kNumAccum / 4 / 5 * 2; i < WGMMA::kNumAccum / 4; ++ i) { + final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3]; + } + } + + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }); + + // Write back to shared memory using STSM + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { + SM90_U32x4_STSM_N::copy( + __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), + __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), + __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), + __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) + ); + } + if constexpr (WGMMA::kNumAccum % 8 != 0) { + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 + ); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Use TMA store to write back to global memory + if (threadIdx.x == 0) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, + scheduler.get_global_idx(SHAPE_M, BLOCK_M, m_block_idx)); + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); + } + __syncwarp(); + break; + case 3: + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + if (threadIdx.x >= 32) { + auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); + auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; + #pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) + st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); + } + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Launch MMAs + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + #pragma unroll + for (int s = 0; s < kNumInnerStages; ++ s) { + // Read B scales + float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; + // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); + + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); + + // Commit WGMMA instructions + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(s); + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4 / 5 * 1; ++ i) { + final_accum[i * 4 + 0] += scale_0_0 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_0_0 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_1_0 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_1_0 * accum[i * 4 + 3]; + } + #pragma unroll + for (int i = WGMMA::kNumAccum / 4 / 5 * 1; i < WGMMA::kNumAccum / 4; ++ i) { + final_accum[i * 4 + 0] += scale_0_1 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_0_1 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_1_1 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_1_1 * accum[i * 4 + 3]; + } + } + + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }); + + // Write back to shared memory using STSM + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { + SM90_U32x4_STSM_N::copy( + __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), + __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), + __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), + __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) + ); + } + if constexpr (WGMMA::kNumAccum % 8 != 0) { + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 + ); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Use TMA store to write back to global memory + if (threadIdx.x == 0) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, + scheduler.get_global_idx(SHAPE_M, BLOCK_M, m_block_idx)); + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); + } + __syncwarp(); + break; + default: + + break; + } + } + } + } else { + auto scheduler = Scheduler(shape_m, grouped_layout); + if (threadIdx.x >= kNumMathThreads) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (threadIdx.x == kNumMathThreads) { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + + // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all + // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait consumer release + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + + auto& full_barrier = *full_barriers[s]; + int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; + + // Issue TMA A with broadcasting + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), + smem_scales_a[s], m_block_idx * BLOCK_M, + scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K)); + + // Issue TMA B without broadcasting + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); + + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + full_barriers[s]->arrive(); + } + }); + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + #pragma unroll + for (uint32_t s = 0; s < kNumStages; ++ s) + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); + const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Decide the number of scales B to load + DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); + uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; + if constexpr (not kMustUseUniformedScaleB) { + num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; + num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8; + } + uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2); + + + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + if (threadIdx.x >= 32) { + auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); + auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; + #pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) + st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); + } + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Accumulation for WGMMA or CUDA promotion + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](int s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void(); + } + }; + + // Launch MMAs launch_k_iterations([&](int k_iter, auto type) { constexpr bool kHasDivisibleStages = std::is_same_v; constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); - // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all - // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant #pragma unroll - for (uint32_t s = 0; s < kNumInnerStages; ++ s) { - // Wait consumer release - empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + for (int s = 0; s < kNumInnerStages; ++ s) { + // Read B scales + float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; + // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); - // Issue TMA A with broadcasting - auto& full_barrier = *full_barriers[s]; - int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; - tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), - smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); - tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), - smem_scales_a[s], m_block_idx * BLOCK_M, - scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K)); + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - // Issue TMA B without broadcasting - tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), - smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); - full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); + + // Commit WGMMA instructions + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(s); + + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + bool predicate = kMustUseUniformedScaleB or i < num_former_iters; + final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + } } // Wait unaligned cases #pragma unroll for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); - full_barriers[s]->arrive(); + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + empty_barrier_arrive(s); } }); - } - // To safely deconstruct distributed shared barriers, we need another round of empty waits - if constexpr (kNumTMAMulticast > 1) { + // Write back to shared memory using STSM + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); #pragma unroll - for (uint32_t s = 0; s < kNumStages; ++ s) - empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1); - } - } - } else { - // Math warp-groups for WGMMA - cutlass::arch::warpgroup_reg_alloc(); - - // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers - const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); - const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; - - // Persistently schedule over blocks - while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - // Decide the number of scales B to load - DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); - uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; - if constexpr (not kMustUseUniformedScaleB) { - num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; - num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8; - } - uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2); - - // Load B scales with math warp-groups - // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks - if (threadIdx.x >= 32) { - auto num_previous_lines = scheduler.get_global_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); - auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; - #pragma unroll - for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) - st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); - } - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Accumulation for WGMMA or CUDA promotion - float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; - - // Empty barrier arrival - auto empty_barrier_arrive = [&](int s) { - if constexpr (kNumTMAMulticast == 1) { - lane_idx == 0 ? empty_barriers[s]->arrive() : void(); - } else { - lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void(); + for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { + SM90_U32x4_STSM_N::copy( + __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), + __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), + __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), + __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) + ); } - }; - - // Launch MMAs - launch_k_iterations([&](int k_iter, auto type) { - constexpr bool kHasDivisibleStages = std::is_same_v; - constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; - DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); - - #pragma unroll - for (int s = 0; s < kNumInnerStages; ++ s) { - // Read B scales - float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1; - // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks - if constexpr (not kMustUseUniformedScaleB) - scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); - - // Wait TMA arrivals - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - - // Read A scales - // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), scale_a_1 = ld_shared(smem_scales_a[s] + r_1); - - // Commit WGMMA instructions - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); - #pragma unroll - for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - warpgroup_commit_batch(); - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); - - // Notify barrier arrival - empty_barrier_arrive(s); - - // Promote with scales - // NOTES: making it as predicates is very important for performance, comparing to two loops - float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; - float scale_0_1, scale_1_1; - if constexpr (not kMustUseUniformedScaleB) - scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; - #pragma unroll - for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - bool predicate = kMustUseUniformedScaleB or i < num_former_iters; - final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; - final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; - final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; - final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; - } + if constexpr (WGMMA::kNumAccum % 8 != 0) { + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 + ); } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); - empty_barrier_arrive(s); + // Use TMA store to write back to global memory + if (threadIdx.x == 0) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, + scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); } - }); - - // Write back to shared memory using STSM - DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); - #pragma unroll - for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) { - SM90_U32x4_STSM_N::copy( - __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), - __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), - __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), - __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16) - ); + __syncwarp(); } - if constexpr (WGMMA::kNumAccum % 8 != 0) { - SM90_U32x2_STSM_N::copy( - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), - __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), - smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16 - ); - } - cute::tma_store_fence(); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Use TMA store to write back to global memory - if (threadIdx.x == 0) { - cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, - scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); - cute::tma_store_arrive(); - cute::tma_store_wait<0>(); - } - __syncwarp(); } } + + #else if (blockIdx.x == 0 and threadIdx.x == 0) DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); #endif } -template ; DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); diff --git a/deep_gemm/include/deep_gemm/mma_utils.cuh b/deep_gemm/include/deep_gemm/mma_utils.cuh index b44bf95..b242261 100644 --- a/deep_gemm/include/deep_gemm/mma_utils.cuh +++ b/deep_gemm/include/deep_gemm/mma_utils.cuh @@ -665,6 +665,135 @@ struct SM90_64x128x32_F32E4M3E4M3_SS { static constexpr int kNumAccum = M * N / 128; }; + +struct SM90_64x144x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, + float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, + float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, + float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, + float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, + float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63, + float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}, " + " %72," + " %73," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], + d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], + d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], + d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], + d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], + d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], + d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63], + d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 144; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + + +struct SM90_64x160x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, + float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, + float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, + float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, + float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, + float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63, + float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71, + float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}, " + " %80," + " %81," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], + d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], + d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], + d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], + d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], + d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], + d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63], + d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71], + d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 160; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + + struct SM90_64x192x32_F32E4M3E4M3_SS { __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, @@ -876,6 +1005,8 @@ struct FP8MMASelector { if constexpr (N == 112) return SM90_64x112x32_F32E4M3E4M3_SS(); if constexpr (N == 120) return SM90_64x120x32_F32E4M3E4M3_SS(); if constexpr (N == 128) return SM90_64x128x32_F32E4M3E4M3_SS(); + if constexpr (N == 144) return SM90_64x144x32_F32E4M3E4M3_SS(); + if constexpr (N == 160) return SM90_64x160x32_F32E4M3E4M3_SS(); if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS(); } diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh index 329fbb0..6393713 100644 --- a/deep_gemm/include/deep_gemm/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/scheduler.cuh @@ -98,6 +98,98 @@ struct Scheduler { return true; } }; +template +struct SchedulerLargeBlockTile { + int current_iter = -1; + uint32_t num_aligned_n_blocks; + + // For normal GEMM + // Maybe not used in the masked grouped GEMM + uint32_t num_blocks; + + // For grouped GEMM + int* grouped_layout; + // Only used for masked layout + uint32_t curr_group_idx, curr_cumsum; + + __device__ __forceinline__ explicit SchedulerLargeBlockTile(const uint32_t shape_n, + int* grouped_layout = nullptr) { + num_aligned_n_blocks = ceil_div(shape_n, BLOCK_N); + if constexpr (kGemmType == GemmType::Normal) { + num_blocks = num_aligned_n_blocks * kNumMBlocks; + } else if (kGemmType == GemmType::GroupedContiguous) { + num_blocks = num_aligned_n_blocks * kNumMBlocks; + this->grouped_layout = grouped_layout; + } else if (kGemmType == GemmType::GroupedMasked) { + curr_group_idx = curr_cumsum = 0; + this->grouped_layout = grouped_layout; + } + } + + __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_n_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { + DG_STATIC_ASSERT(kNumMBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); + + // Swizzle for better L2 usages + auto num_blocks_per_group = num_n_blocks * kNumMBlocksPerGroup; + auto group_idx = block_idx / num_blocks_per_group; + auto first_m_block_idx = group_idx * kNumMBlocksPerGroup; + auto num_m_blocks_in_group = min(kNumMBlocksPerGroup, kNumMBlocks - first_m_block_idx); + // auto num_m_blocks_in_group = kNumMBlocksPerGroup; + auto in_group_idx = block_idx % num_blocks_per_group; + n_block_idx = in_group_idx / num_m_blocks_in_group; + m_block_idx = first_m_block_idx + in_group_idx % num_m_blocks_in_group; + // if (threadIdx.x == 256) { + // printf("blockIdx.x: %d group_idx: %d in_group_idx: %d m_block_idx: %d n_block_idx: %d\n", blockIdx.x, group_idx, in_group_idx, m_block_idx, n_block_idx); + // } + } + + template + __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size, + const uint32_t& block_idx, const uint32_t& m_block_idx=0) { + if constexpr (kGemmType == GemmType::Normal) { + return block_idx * block_size; + } else if (kGemmType == GemmType::GroupedContiguous) { + auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M); + return offset * shape_dim + block_idx * block_size; + } else if (kGemmType == GemmType::GroupedMasked) { + return curr_group_idx * shape_dim + block_idx * block_size; + } + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { + const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x; + + if constexpr (kGemmType == GemmType::GroupedMasked) { + uint32_t num_m_blocks; + while (true) { + // End of the task + if (curr_group_idx == kNumGroups) + return false; + + // Within current group + num_m_blocks = ceil_div(static_cast(__ldg(grouped_layout + curr_group_idx)), BLOCK_M); + auto current_m_block_cumsum = curr_cumsum + num_m_blocks; + if (next_block_idx < current_m_block_cumsum * kNumMBlocks) + break; + + // Move to check the next group + curr_group_idx ++, curr_cumsum = current_m_block_cumsum; + } + + get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumMBlocks, m_block_idx, n_block_idx); + } else { + if (next_block_idx >= num_blocks) + return false; + + get_swizzled_block_idx(num_aligned_n_blocks, next_block_idx, m_block_idx, n_block_idx); + } + return true; + } +}; #pragma clang diagnostic pop } // namespace deep_gemm diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index d97a615..903c1a0 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -10,14 +10,14 @@ template = """ using namespace deep_gemm; // Templated args from Python JIT call -constexpr auto N = {N}, K = {K}; +constexpr auto M = {M}, N = {N}, K = {K}; constexpr auto BLOCK_M = {BLOCK_M}; constexpr auto BLOCK_N = {BLOCK_N}; constexpr auto kNumStages = {NUM_STAGES}; constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; // Make a templated GEMM -using GemmType = Gemm; +using GemmType = Gemm; // Launch kernel auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m); @@ -62,7 +62,14 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, block_ms = (64 if m <= 64 else 128, ) else: block_ms = (get_m_alignment_for_contiguous_layout(), ) - block_ns = tuple(range(16, 129, 8)) + # Current optimizations target large-scale Normal GEMMs for dense models and + # Grouped GEMMs for MoE models (contiguous memory layout), with a potential + # block_n tile size of 160 to enhance data reuse in block tiling. + if m >= 4096 and num_groups == 1: + block_ns = tuple((16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 160)) + else: + block_ns = tuple(range(16, 129, 8)) + fix_wave_saturate = lambda x: num_sms if x == 0 else x get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) @@ -77,7 +84,16 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, if best_block_m is None or best_block_n is None: success = True elif num_waves < best_num_waves: - success = True + # The value 0.02 is currently an empirically estimated threshold to + # filter out cases unsuitable for large block tile configurations, + # with optimizations planned for later stages to address excluded scenarios. + if block_n == 160 and \ + (num_waves * block_m * block_n - best_num_waves * best_block_m * best_block_n) / (best_num_waves * best_block_m * best_block_n) < 0.02: + success = True + elif block_n == 160: + success = False + else: + success = True elif num_waves == best_num_waves: # Check last wave utilization util = get_last_wave_util(block_m, block_n) @@ -89,24 +105,44 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Always pick the longest one # NOTES: for double B scales, the best number of stages may be reduced best_num_stages, best_smem_size, sm90_capacity = None, None, 232448 - for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4): + if best_block_n != 160: + for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4): + best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n) + if best_smem_size <= sm90_capacity: + best_num_stages = num_stages + break + else: + # NOTES: This is done to reduce the code footprint after unrolling. + # Additionally, if k does not meet the following conditions, a slight performance penalty will occur. + num_stages = 4 + assert k / 128 % num_stages == 0 best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n) - if best_smem_size <= sm90_capacity: - best_num_stages = num_stages - break + assert best_smem_size <= sm90_capacity + best_num_stages = num_stages + assert best_num_stages is not None # Decide the number of TMA multicast best_num_tma_multicast = 1 - if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1: - best_num_tma_multicast = 2 + # When using large block tiling, broadcasting B is required to achieve maximum performance gains. + if best_block_n != 160: + if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1: + best_num_tma_multicast = 2 + else: + if m >= 4096 and is_tma_multicast_legal(m, best_block_m, 2, num_sms) and num_groups == 1: + best_num_tma_multicast = 2 + # Recompute the minimal number of SMs required # NOTES: less L2 cache usage and less GPU frequency drop num_waves = get_num_waves(best_block_m, best_block_n) num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves) num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), best_num_tma_multicast) * best_num_tma_multicast - assert num_min_sms <= num_sms and is_tma_multicast_legal(n, best_block_n, best_num_tma_multicast, num_min_sms) + if best_block_n != 160: + assert num_min_sms <= num_sms and is_tma_multicast_legal(n, best_block_n, best_num_tma_multicast, num_min_sms) + else: + assert num_min_sms <= num_sms and is_tma_multicast_legal(m, best_block_m, best_num_tma_multicast, num_min_sms) + return num_min_sms, best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size @@ -162,7 +198,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size) runtime = jit_tuner.compile_and_tune( name='gemm_fp8_fp8_bf16_nt', - keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, + keys={'M': m, 'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast}, space=(), includes=includes, @@ -176,3 +212,5 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], # Run the kernel runtime(*args) + # For debug + return num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index 415fc67..c3b26a7 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -11,14 +11,14 @@ template = """ using namespace deep_gemm; // Templated args from Python JIT call -constexpr auto N = {N}, K = {K}; +constexpr auto M = {M}, N = {N}, K = {K}; constexpr auto BLOCK_M = {BLOCK_M}; constexpr auto BLOCK_N = {BLOCK_N}; constexpr auto kNumStages = {NUM_STAGES}; constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; // Make a templated grouped GEMM -using GemmType = Gemm; +using GemmType = Gemm; // Launch kernel auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m); @@ -91,7 +91,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten torch.cuda.current_stream(), num_sms, smem_size) runtime = jit_tuner.compile_and_tune( name='m_grouped_gemm_fp8_fp8_bf16_nt', - keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups, + keys={'M': m, 'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups, 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedContiguous'}, space=(), includes=includes, @@ -106,6 +106,8 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten # Run the kernel runtime(*args) + # For debug + return num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor], @@ -169,7 +171,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] torch.cuda.current_stream(), num_sms, smem_size) runtime = jit_tuner.compile_and_tune( name='m_grouped_gemm_fp8_fp8_bf16_nt', - keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups, + keys={'M': m, 'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups, 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedMasked'}, space=(), includes=includes, @@ -184,3 +186,5 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] # Run the kernel runtime(*args) + # For debug + return num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size