[Kernel] Add enable_sm120_or_later for SM121 (DGX Spark) CUTLASS support (#33517)
Signed-off-by: code4me2 <velvetmoon222999@gmail.com>
This commit is contained in:
@@ -152,3 +152,14 @@ struct enable_sm120_only : Kernel {
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// SM12x family includes SM120 (RTX 5090) and SM121 (DGX Spark GB10)
|
||||
template <typename Kernel>
|
||||
struct enable_sm120_family : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && (__CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300)
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -103,7 +103,8 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
||||
MainloopScheduler
|
||||
>::CollectiveOp;
|
||||
|
||||
using KernelType = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
|
||||
// SM12x family to support both SM120 (RTX 5090) and SM121 (DGX Spark)
|
||||
using KernelType = enable_sm120_family<cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
|
||||
Reference in New Issue
Block a user