[Kernel] Add GPU architecture guards to the CUTLASS w8a8 kernels to reduce binary size (#5157)

Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
Tyler Michael Smith
2024-06-05 13:44:15 -04:00
committed by GitHub
parent 02cc3b51a7
commit ccd4f129e8
2 changed files with 87 additions and 37 deletions

View File

@@ -56,6 +56,21 @@ uint32_t next_pow_2(uint32_t const num) {
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
Kernel::operator()(std::forward<Args>(args)...);
#endif
}
};
template <typename ElementAB_, typename ElementD_, typename TileShape,
typename ClusterShape, typename KernelSchedule,
typename EpilogueSchedule>
@@ -126,9 +141,9 @@ struct cutlass_3x_gemm {
KernelSchedule>::CollectiveOp;
// clang-format on
using KernelType = cutlass::gemm::kernel::GemmUniversal<
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
cutlass::gemm::PersistentScheduler>;
cutlass::gemm::PersistentScheduler>>;
struct GemmKernel : public KernelType {};
};