[Kernel] Add GPU architecture guards to the CUTLASS w8a8 kernels to reduce binary size (#5157)
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
committed by
GitHub
parent
02cc3b51a7
commit
ccd4f129e8
@@ -56,6 +56,21 @@ uint32_t next_pow_2(uint32_t const num) {
|
||||
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
|
||||
}
|
||||
|
||||
// A wrapper for the GEMM kernel that is used to guard against compilation on
|
||||
// architectures that will never use the kernel. The purpose of this is to
|
||||
// reduce the size of the compiled binary.
|
||||
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||
// into code that will be executed on the device where it is defined.
|
||||
template <typename Kernel>
|
||||
struct enable_sm90_or_later : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ElementAB_, typename ElementD_, typename TileShape,
|
||||
typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule>
|
||||
@@ -126,9 +141,9 @@ struct cutlass_3x_gemm {
|
||||
KernelSchedule>::CollectiveOp;
|
||||
// clang-format on
|
||||
|
||||
using KernelType = cutlass::gemm::kernel::GemmUniversal<
|
||||
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
|
||||
cutlass::gemm::PersistentScheduler>;
|
||||
cutlass::gemm::PersistentScheduler>>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user