diff --git a/README.md b/README.md index d3cbbdf..d1b0817 100644 --- a/README.md +++ b/README.md @@ -28,21 +28,21 @@ DeepGEMM does not behave very well on some shapes, optimization PRs are welcomed | 128 | 7168 | 16384 | 645 TFLOPS | 2604 GB/s | 1.4x | | 128 | 4096 | 7168 | 533 TFLOPS | 2221 GB/s | 2.0x | | 128 | 7168 | 2048 | 510 TFLOPS | 2277 GB/s | 1.7x | -| 4096 | 2112 | 7168 | 1058 TFLOPS | 527 GB/s | 1.1x | -| 4096 | 24576 | 1536 | 990 TFLOPS | 786 GB/s | 1.0x | -| 4096 | 32768 | 512 | 590 TFLOPS | 1232 GB/s | 1.0x | -| 4096 | 7168 | 16384 | 1358 TFLOPS | 343 GB/s | 1.2x | -| 4096 | 4096 | 7168 | 1304 TFLOPS | 500 GB/s | 1.1x | -| 4096 | 7168 | 2048 | 1025 TFLOPS | 697 GB/s | 1.1x | +| 4096 | 2112 | 7168 | 1009 TFLOPS | 503 GB/s | 1.1x | +| 4096 | 24576 | 1536 | 1125 TFLOPS | 893 GB/s | 1.1x | +| 4096 | 32768 | 512 | 751 TFLOPS | 1569 GB/s | 1.1x | +| 4096 | 7168 | 16384 | 1426 TFLOPS | 361 GB/s | 1.3x | +| 4096 | 4096 | 7168 | 1265 TFLOPS | 485 GB/s | 1.2x | +| 4096 | 7168 | 2048 | 1168 TFLOPS | 794 GB/s | 1.2x | ### Grouped GEMMs for MoE models (contiguous layout) | #Groups | M per group | N | K | Computation | Memory bandwidth | Speedup | |:-------:|:-----------:|:----:|:----:|:-----------:|:----------------:|:-------:| -| 4 | 8192 | 4096 | 7168 | 1297 TFLOPS | 418 GB/s | 1.2x | -| 4 | 8192 | 7168 | 2048 | 1099 TFLOPS | 681 GB/s | 1.2x | -| 8 | 4096 | 4096 | 7168 | 1288 TFLOPS | 494 GB/s | 1.2x | -| 8 | 4096 | 7168 | 2048 | 1093 TFLOPS | 743 GB/s | 1.1x | +| 4 | 8192 | 4096 | 7168 | 1346 TFLOPS | 434 GB/s | 1.3x | +| 4 | 8192 | 7168 | 2048 | 1214 TFLOPS | 752 GB/s | 1.3x | +| 8 | 4096 | 4096 | 7168 | 1346 TFLOPS | 516 GB/s | 1.3x | +| 8 | 4096 | 7168 | 2048 | 1214 TFLOPS | 826 GB/s | 1.2x | ### Grouped GEMMs for MoE models (masked layout) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index ee6e4a4..d9ab480 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -27,11 +27,22 @@ __device__ __host__ constexpr int get_num_threads_per_sm(int block_m) { return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads; } +template +__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, int num_former_iters) { + if (num_former_iters == kNumFormerIters) { + inner_launch_k_iterations(func, cute::Int{}); + return; + } + + if constexpr (kNumFormerIters + kGap <= kEnd) + outer_launch_k_iterations(inner_launch_k_iterations, func, num_former_iters); +} + template __global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M), 1) fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, @@ -43,7 +54,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) // Scaling checks DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); - DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block"); + DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1 or (gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); // Types using WGMMA = typename FP8MMASelector::type; @@ -129,15 +140,23 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // For pipeline unrolling struct DivisibleK {}; struct NotDivisibleK {}; - auto launch_k_iterations = [](const auto& func) { - if constexpr (SHAPE_K % kFullKOfAllStages == 0) { - for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter) - func(k_iter, DivisibleK{}); - } else { - for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter) - func(k_iter, DivisibleK{}); - func(kNumIterations - 1, NotDivisibleK{}); - } + auto launch_k_iterations = [](const auto& func, int num_former_iters) { + constexpr bool kShouldOptimize = BLOCK_K / gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; + constexpr int kGap = gcd(BLOCK_K, BLOCK_N) / 8; + constexpr int kEnd = kShouldOptimize ? BLOCK_K / 8 : 0; + + // NOTES: for too-many branches (> 5), we disable this optimization + // Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value + outer_launch_k_iterations<0, kGap, kEnd>([](const auto& func, auto num_former_iters_type) { + if constexpr (SHAPE_K % kFullKOfAllStages == 0) { + for (int k_iter = 0; k_iter < kNumIterations; ++ k_iter) + func(k_iter, DivisibleK{}, num_former_iters_type); + } else { + for (int k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter) + func(k_iter, DivisibleK{}, num_former_iters_type); + func(kNumIterations - 1, NotDivisibleK{}, num_former_iters_type); + } + }, func, kShouldOptimize ? num_former_iters : 0); }; // Register reconfigurations @@ -146,7 +165,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, grouped_layout); + auto scheduler = Scheduler(shape_m, grouped_layout); if (threadIdx.x >= kNumMathThreads) { // TMA warp-group for loading data @@ -156,11 +175,15 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, if (threadIdx.x == kNumMathThreads) { // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - launch_k_iterations([&](int k_iter, auto type) { + launch_k_iterations([&](int k_iter, auto type, auto _) { constexpr bool kHasDivisibleStages = std::is_same_v; constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + // Assign TMA multicast number into A and B + constexpr int kNumTMAMulticastOnA = kIsTMAMulticastOnA ? kNumTMAMulticast : 1; + constexpr int kNumTMAMulticastOnB = kIsTMAMulticastOnA ? 1 : kNumTMAMulticast; + // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant #pragma unroll @@ -168,18 +191,18 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Wait consumer release empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); - // Issue TMA A with broadcasting + // Issue TMA A auto& full_barrier = *full_barriers[s]; int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; - tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), - smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); - tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), - smem_scales_a[s], m_block_idx * BLOCK_M, - scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K)); + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + tma_copy(&tensor_map_scales_a, reinterpret_cast(&full_barrier), + smem_scales_a[s], m_block_idx * BLOCK_M, + scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K)); - // Issue TMA B without broadcasting - tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), - smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); + // Issue TMA B + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[s], k_idx, scheduler.get_global_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); } @@ -189,7 +212,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); full_barriers[s]->arrive(); } - }); + }, 0); } // To safely deconstruct distributed shared barriers, we need another round of empty waits @@ -242,7 +265,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, }; // Launch MMAs - launch_k_iterations([&](int k_iter, auto type) { + launch_k_iterations([&](int k_iter, auto type, auto num_former_iters_type) { constexpr bool kHasDivisibleStages = std::is_same_v; constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); @@ -288,8 +311,10 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, float scale_0_1, scale_1_1; if constexpr (not kMustUseUniformedScaleB) scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + #pragma unroll for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant bool predicate = kMustUseUniformedScaleB or i < num_former_iters; final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; @@ -304,7 +329,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); empty_barrier_arrive(s); } - }); + }, num_former_iters); // Write back to shared memory using STSM DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); @@ -347,7 +372,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, template class Gemm { private: @@ -369,7 +394,7 @@ public: constexpr uint32_t kNumMathThreadsPerGroup = 128; auto kernel = fp8_gemm_kernel; + kNumTMAMulticast, kIsTMAMulticastOnA, kGemmType>; DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); // Cluster launch diff --git a/deep_gemm/include/deep_gemm/mma_utils.cuh b/deep_gemm/include/deep_gemm/mma_utils.cuh index b44bf95..b242261 100644 --- a/deep_gemm/include/deep_gemm/mma_utils.cuh +++ b/deep_gemm/include/deep_gemm/mma_utils.cuh @@ -665,6 +665,135 @@ struct SM90_64x128x32_F32E4M3E4M3_SS { static constexpr int kNumAccum = M * N / 128; }; + +struct SM90_64x144x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, + float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, + float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, + float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, + float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, + float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63, + float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}, " + " %72," + " %73," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], + d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], + d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], + d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], + d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], + d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], + d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63], + d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 144; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + + +struct SM90_64x160x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, + float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, + float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, + float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, + float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, + float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, + float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63, + float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71, + float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79, + bool scale_d) { + asm volatile("{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}, " + " %80," + " %81," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + wgmma(desc_a, desc_b, + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], + d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], + d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], + d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], + d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], + d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], + d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], + d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63], + d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71], + d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 160; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + + struct SM90_64x192x32_F32E4M3E4M3_SS { __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, @@ -876,6 +1005,8 @@ struct FP8MMASelector { if constexpr (N == 112) return SM90_64x112x32_F32E4M3E4M3_SS(); if constexpr (N == 120) return SM90_64x120x32_F32E4M3E4M3_SS(); if constexpr (N == 128) return SM90_64x128x32_F32E4M3E4M3_SS(); + if constexpr (N == 144) return SM90_64x144x32_F32E4M3E4M3_SS(); + if constexpr (N == 160) return SM90_64x160x32_F32E4M3E4M3_SS(); if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS(); } diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh index 329fbb0..6e3cb52 100644 --- a/deep_gemm/include/deep_gemm/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/scheduler.cuh @@ -12,9 +12,10 @@ enum class GemmType { #pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" template + uint32_t kNum1DBlocksPerGroup = 16> struct Scheduler { int current_iter = -1; uint32_t num_aligned_m_blocks; @@ -43,16 +44,27 @@ struct Scheduler { } __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { - DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); + DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); // Swizzle for better L2 usages - auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup; - auto group_idx = block_idx / num_blocks_per_group; - auto first_n_block_idx = group_idx * kNumNBlocksPerGroup; - auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx); - auto in_group_idx = block_idx % num_blocks_per_group; - m_block_idx = in_group_idx / num_n_blocks_in_group; - n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group; + // TODO: unify these 2 branches + if constexpr (kIsTMAMulticastOnA) { + auto num_blocks_per_group = num_m_blocks * kNum1DBlocksPerGroup; + auto group_idx = block_idx / num_blocks_per_group; + auto first_n_block_idx = group_idx * kNum1DBlocksPerGroup; + auto num_n_blocks_in_group = min(kNum1DBlocksPerGroup, kNumNBlocks - first_n_block_idx); + auto in_group_idx = block_idx % num_blocks_per_group; + m_block_idx = in_group_idx / num_n_blocks_in_group; + n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group; + } else { + auto num_blocks_per_group = kNumNBlocks * kNum1DBlocksPerGroup; + auto group_idx = block_idx / num_blocks_per_group; + auto first_m_block_idx = group_idx * kNum1DBlocksPerGroup; + auto num_m_blocks_in_group = min(kNum1DBlocksPerGroup, num_m_blocks - first_m_block_idx); + auto in_group_idx = block_idx % num_blocks_per_group; + m_block_idx = first_m_block_idx + in_group_idx % num_m_blocks_in_group; + n_block_idx = in_group_idx / num_m_blocks_in_group; + } } template @@ -98,6 +110,7 @@ struct Scheduler { return true; } }; + #pragma clang diagnostic pop } // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/utils.cuh b/deep_gemm/include/deep_gemm/utils.cuh index 0005907..fe2c016 100644 --- a/deep_gemm/include/deep_gemm/utils.cuh +++ b/deep_gemm/include/deep_gemm/utils.cuh @@ -46,3 +46,8 @@ template __device__ __host__ constexpr T ceil_div(T a, T b) { return (a + b - 1) / b; } + +template +__device__ __host__ constexpr T gcd(T a, T b) { + return b == 0 ? a : gcd(b, a % b); +} diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py index 0f099d8..aad8939 100644 --- a/deep_gemm/jit/compiler.py +++ b/deep_gemm/jit/compiler.py @@ -101,7 +101,7 @@ def build(name: str, arg_defs: tuple, code: str) -> Runtime: '--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''), # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases '--diag-suppress=177,174,940'] - cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi'] + cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts'] flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}'] include_dirs = [get_jit_include_dir()] diff --git a/deep_gemm/jit/template.py b/deep_gemm/jit/template.py index cdca4c4..ead37f5 100644 --- a/deep_gemm/jit/template.py +++ b/deep_gemm/jit/template.py @@ -67,7 +67,10 @@ def cpp_format(template: str, keys: Dict[str, Any]) -> str: # We don't use `str.format` because it's not safe for C++ {} braces new_template = copy.deepcopy(template) for key, value in keys.items(): - new_template = new_template.replace(f'{{{key}}}', f'{value}') + value_str = str(value) + if isinstance(value, bool): + value_str = value_str.lower() + new_template = new_template.replace(f'{{{key}}}', f'{value_str}') return new_template diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index d97a615..65b44ff 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -1,3 +1,4 @@ +import math import torch from typing import Tuple @@ -15,9 +16,10 @@ constexpr auto BLOCK_M = {BLOCK_M}; constexpr auto BLOCK_N = {BLOCK_N}; constexpr auto kNumStages = {NUM_STAGES}; constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; +constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A}; // Make a templated GEMM -using GemmType = Gemm; +using GemmType = Gemm; // Launch kernel auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m); @@ -31,10 +33,10 @@ GemmType::run(out, rhs_scales, nullptr, """ -def is_tma_multicast_legal(n: int, block_n: int, num_tma_multicast: int, num_sms: int) -> bool: +def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int) -> bool: if num_tma_multicast == 1: return True - return (n % (block_n * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0 + return (shape_dim % (block_dim * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0 def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int: @@ -56,13 +58,13 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, - is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, int, int]: + is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, Tuple[int, bool], int]: if not is_grouped_contiguous: # TODO: for some cases, smaller M block is better, add them into tuning space block_ms = (64 if m <= 64 else 128, ) else: block_ms = (get_m_alignment_for_contiguous_layout(), ) - block_ns = tuple(range(16, 129, 8)) + block_ns = tuple(range(16, 129, 8)) + (144, 160, ) fix_wave_saturate = lambda x: num_sms if x == 0 else x get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) @@ -89,26 +91,38 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Always pick the longest one # NOTES: for double B scales, the best number of stages may be reduced best_num_stages, best_smem_size, sm90_capacity = None, None, 232448 - for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4): + stage_candidates = (8, 7, 6, 5, 4) + if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4: + # Unrolling both stages and `num_former_iters` will cause large code size + stage_candidates = (4, ) + for num_stages in stage_candidates: best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n) if best_smem_size <= sm90_capacity: best_num_stages = num_stages break assert best_num_stages is not None - # Decide the number of TMA multicast - best_num_tma_multicast = 1 - if m >= 1024 and is_tma_multicast_legal(n, best_block_n, 2, num_sms) and num_groups == 1: - best_num_tma_multicast = 2 + # Decide the number of TMA multicast and whether broadcast on A + best_tma_multicast_config = (1, True) + + # Try to multicast on the larger block side first + is_multicast_legal = { + 'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms), + 'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms), + } + for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'): + if m >= 512 and is_multicast_legal[i] and num_groups == 1: + best_tma_multicast_config = (2, i == 'A') + break # Recompute the minimal number of SMs required # NOTES: less L2 cache usage and less GPU frequency drop num_waves = get_num_waves(best_block_m, best_block_n) num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves) - num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), best_num_tma_multicast) * best_num_tma_multicast - assert num_min_sms <= num_sms and is_tma_multicast_legal(n, best_block_n, best_num_tma_multicast, num_min_sms) + num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), best_tma_multicast_config[0]) * best_tma_multicast_config[0] + assert num_min_sms <= num_sms - return num_min_sms, best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size + return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_size def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], @@ -158,12 +172,14 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], # Auto-tuning with compilation global includes, template num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms) + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(m, n, k, 1, num_sms) args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size) runtime = jit_tuner.compile_and_tune( name='gemm_fp8_fp8_bf16_nt', keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, - 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast}, + 'NUM_STAGES': num_stages, + 'NUM_TMA_MULTICAST': tma_multicast_config[0], + 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]}, space=(), includes=includes, arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index 415fc67..bffe137 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -16,9 +16,10 @@ constexpr auto BLOCK_M = {BLOCK_M}; constexpr auto BLOCK_N = {BLOCK_N}; constexpr auto kNumStages = {NUM_STAGES}; constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST}; +constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A}; // Make a templated grouped GEMM -using GemmType = Gemm; +using GemmType = Gemm; // Launch kernel auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m); @@ -84,15 +85,17 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten # Auto-tuning with compilation global includes, template num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms, - is_grouped_contiguous=True) + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True) args = (lhs, lhs_scales, rhs, rhs_scales, out, m_indices, m, num_groups, torch.cuda.current_stream(), num_sms, smem_size) runtime = jit_tuner.compile_and_tune( name='m_grouped_gemm_fp8_fp8_bf16_nt', keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups, - 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedContiguous'}, + 'NUM_STAGES': num_stages, + 'NUM_TMA_MULTICAST': tma_multicast_config[0], + 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], + 'GEMM_TYPE': 'GroupedContiguous'}, space=(), includes=includes, arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), @@ -158,7 +161,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] # Auto-tuning with compilation global includes, template num_sms = get_num_sms() - num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms) + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms) # Extra checks for TMA store if num_groups > 1 and m > block_m: @@ -170,7 +173,10 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] runtime = jit_tuner.compile_and_tune( name='m_grouped_gemm_fp8_fp8_bf16_nt', keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups, - 'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedMasked'}, + 'NUM_STAGES': num_stages, + 'NUM_TMA_MULTICAST': tma_multicast_config[0], + 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], + 'GEMM_TYPE': 'GroupedMasked'}, space=(), includes=includes, arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),