fix cutlass_3x_gemm_fp8_blockwise on sm103a (#32224)
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com> Co-authored-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
@@ -3,7 +3,8 @@
|
||||
#include "cutlass/cutlass.h"
|
||||
#include <climits>
|
||||
#include "cuda_runtime.h"
|
||||
#include <iostream>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
|
||||
/**
|
||||
* Helper function for checking CUTLASS errors
|
||||
@@ -31,12 +32,63 @@ int32_t get_sm_version_num();
|
||||
* __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||
* into code that will be executed on the device where it is defined.
|
||||
*/
|
||||
|
||||
template <typename Kernel>
|
||||
struct enable_sm75_to_sm80 : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__
|
||||
#if __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
|
||||
Kernel::invoke(std::forward<Args>(args)...);
|
||||
#else
|
||||
printf("This kernel only supports sm[75, 80).\n");
|
||||
asm("trap;");
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Kernel>
|
||||
struct enable_sm80_to_sm89 : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__
|
||||
#if __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
|
||||
Kernel::invoke(std::forward<Args>(args)...);
|
||||
#else
|
||||
printf("This kernel only supports sm[80, 89).\n");
|
||||
asm("trap;");
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Kernel>
|
||||
struct enable_sm89_to_sm90 : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__
|
||||
#if __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
|
||||
Kernel::invoke(std::forward<Args>(args)...);
|
||||
#else
|
||||
printf("This kernel only supports sm[89, 90).\n");
|
||||
asm("trap;");
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Kernel>
|
||||
struct enable_sm90_or_later : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
|
||||
#if defined __CUDA_ARCH__
|
||||
#if __CUDA_ARCH__ >= 900
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#else
|
||||
printf("This kernel only supports sm >= 90.\n");
|
||||
asm("trap;");
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
};
|
||||
@@ -45,18 +97,43 @@ template <typename Kernel>
|
||||
struct enable_sm90_only : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 900
|
||||
#if defined __CUDA_ARCH__
|
||||
#if __CUDA_ARCH__ == 900
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#else
|
||||
printf("This kernel only supports sm90.\n");
|
||||
asm("trap;");
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Kernel>
|
||||
struct enable_sm100_only : Kernel {
|
||||
struct enable_sm100f_only : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000
|
||||
#if defined __CUDA_ARCH__
|
||||
#if __CUDA_ARCH__ == 1000 || __CUDA_ARCH__ == 1030
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#else
|
||||
printf("This kernel only supports sm100f.\n");
|
||||
asm("trap;");
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Kernel>
|
||||
struct enable_sm100a_only : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__
|
||||
#if __CUDA_ARCH__ == 1000
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#else
|
||||
printf("This kernel only supports sm100a.\n");
|
||||
asm("trap;");
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
};
|
||||
@@ -65,8 +142,13 @@ template <typename Kernel>
|
||||
struct enable_sm120_only : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1200
|
||||
#if defined __CUDA_ARCH__
|
||||
#if __CUDA_ARCH__ == 1200
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#else
|
||||
printf("This kernel only supports sm120.\n");
|
||||
asm("trap;");
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -141,8 +141,8 @@ struct cutlass_3x_gemm_sm100 {
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||
using GemmKernel = enable_sm100f_only<cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>>;
|
||||
};
|
||||
|
||||
template <typename ElementAB_, typename ElementD_,
|
||||
@@ -202,8 +202,8 @@ struct cutlass_3x_gemm_sm120 {
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||
using GemmKernel = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>>;
|
||||
};
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
@@ -123,7 +123,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
||||
MainloopScheduler
|
||||
>::CollectiveOp>;
|
||||
|
||||
using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal<
|
||||
using KernelType = enable_sm100f_only<cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
|
||||
@@ -90,8 +90,8 @@ struct cutlass_3x_gemm_sm100_fp8 {
|
||||
// -----------------------------------------------------------
|
||||
// Kernel definition
|
||||
// -----------------------------------------------------------
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||
using GemmKernel = enable_sm100f_only<cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType, bool EnableBias>
|
||||
|
||||
@@ -36,41 +36,6 @@ using namespace cute;
|
||||
*/
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Wrappers for the GEMM kernel that is used to guard against compilation on
|
||||
// architectures that will never use the kernel. The purpose of this is to
|
||||
// reduce the size of the compiled binary.
|
||||
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||
// into code that will be executed on the device where it is defined.
|
||||
template <typename Kernel>
|
||||
struct enable_sm75_to_sm80 : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
|
||||
Kernel::invoke(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Kernel>
|
||||
struct enable_sm80_to_sm89 : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
|
||||
Kernel::invoke(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Kernel>
|
||||
struct enable_sm89_to_sm90 : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
|
||||
Kernel::invoke(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
template <typename Arch, template <typename> typename ArchGuard,
|
||||
typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename> typename Epilogue_, typename TileShape,
|
||||
|
||||
@@ -50,7 +50,7 @@ struct sm89_fp8_config_default {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -58,7 +58,7 @@ struct sm89_fp8_config_default {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<256, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -67,7 +67,7 @@ struct sm89_fp8_config_default {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -100,7 +100,7 @@ struct sm89_fp8_config_M256 {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -108,7 +108,7 @@ struct sm89_fp8_config_M256 {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -141,7 +141,7 @@ struct sm89_fp8_config_M128 {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -150,7 +150,7 @@ struct sm89_fp8_config_M128 {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -158,7 +158,7 @@ struct sm89_fp8_config_M128 {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -191,7 +191,7 @@ struct sm89_fp8_config_M64 {
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -201,7 +201,7 @@ struct sm89_fp8_config_M64 {
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -211,7 +211,7 @@ struct sm89_fp8_config_M64 {
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -244,7 +244,7 @@ struct sm89_fp8_config_M32 {
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -253,7 +253,7 @@ struct sm89_fp8_config_M32 {
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 4, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -262,7 +262,7 @@ struct sm89_fp8_config_M32 {
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -296,7 +296,7 @@ struct sm89_fp8_config_M16 {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, MainLoopStages,
|
||||
FP8MathOperator>,
|
||||
@@ -305,7 +305,7 @@ struct sm89_fp8_config_M16 {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<16, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, MainLoopStages,
|
||||
FP8MathOperator>,
|
||||
@@ -314,7 +314,7 @@ struct sm89_fp8_config_M16 {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, MainLoopStages,
|
||||
FP8MathOperator>,
|
||||
|
||||
@@ -48,7 +48,7 @@ struct sm89_int8_config_default {
|
||||
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -56,7 +56,7 @@ struct sm89_int8_config_default {
|
||||
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -64,7 +64,7 @@ struct sm89_int8_config_default {
|
||||
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -72,7 +72,7 @@ struct sm89_int8_config_default {
|
||||
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -104,7 +104,7 @@ struct sm89_int8_config_M256 {
|
||||
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -112,7 +112,7 @@ struct sm89_int8_config_M256 {
|
||||
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -120,7 +120,7 @@ struct sm89_int8_config_M256 {
|
||||
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -128,7 +128,7 @@ struct sm89_int8_config_M256 {
|
||||
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -160,7 +160,7 @@ struct sm89_int8_config_M128 {
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -169,7 +169,7 @@ struct sm89_int8_config_M128 {
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -178,7 +178,7 @@ struct sm89_int8_config_M128 {
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -210,7 +210,7 @@ struct sm89_int8_config_M64 {
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -219,7 +219,7 @@ struct sm89_int8_config_M64 {
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -251,7 +251,7 @@ struct sm89_int8_config_M32 {
|
||||
using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -260,7 +260,7 @@ struct sm89_int8_config_M32 {
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 4>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -292,7 +292,7 @@ struct sm89_int8_config_M16 {
|
||||
using TileShape = cutlass::gemm::GemmShape<16, 64, 128>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
@@ -300,7 +300,7 @@ struct sm89_int8_config_M16 {
|
||||
using TileShape = cutlass::gemm::GemmShape<16, 128, 128>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 4>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
|
||||
Reference in New Issue
Block a user