diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp index 1d5280139..91c215071 100644 --- a/csrc/cutlass_extensions/common.hpp +++ b/csrc/cutlass_extensions/common.hpp @@ -152,3 +152,14 @@ struct enable_sm120_only : Kernel { #endif } }; + +// SM12x family includes SM120 (RTX 5090) and SM121 (DGX Spark GB10) +template +struct enable_sm120_family : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && (__CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300) + Kernel::operator()(std::forward(args)...); +#endif + } +}; diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh index 811741aee..f255b27a1 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh +++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh @@ -103,7 +103,8 @@ struct cutlass_3x_gemm_fp8_blockwise { MainloopScheduler >::CollectiveOp; - using KernelType = enable_sm120_only, CollectiveMainloop, CollectiveEpilogue>>; struct GemmKernel : public KernelType {};