fix cutlass_3x_gemm_fp8_blockwise on sm103a (#32224)

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Co-authored-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
Lain
2026-02-02 11:47:46 -08:00
committed by GitHub
parent 0130223bd9
commit 089cd4f002
7 changed files with 129 additions and 82 deletions

View File

@@ -3,7 +3,8 @@
#include "cutlass/cutlass.h"
#include <climits>
#include "cuda_runtime.h"
#include <iostream>
#include <cstdio>
#include <cstdlib>
/**
* Helper function for checking CUTLASS errors
@@ -31,12 +32,63 @@ int32_t get_sm_version_num();
* __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
* into code that will be executed on the device where it is defined.
*/
template <typename Kernel>
struct enable_sm75_to_sm80 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__
#if __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
Kernel::invoke(std::forward<Args>(args)...);
#else
printf("This kernel only supports sm[75, 80).\n");
asm("trap;");
#endif
#endif
}
};
template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__
#if __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
Kernel::invoke(std::forward<Args>(args)...);
#else
printf("This kernel only supports sm[80, 89).\n");
asm("trap;");
#endif
#endif
}
};
template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__
#if __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
Kernel::invoke(std::forward<Args>(args)...);
#else
printf("This kernel only supports sm[89, 90).\n");
asm("trap;");
#endif
#endif
}
};
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
#if defined __CUDA_ARCH__
#if __CUDA_ARCH__ >= 900
Kernel::operator()(std::forward<Args>(args)...);
#else
printf("This kernel only supports sm >= 90.\n");
asm("trap;");
#endif
#endif
}
};
@@ -45,18 +97,43 @@ template <typename Kernel>
struct enable_sm90_only : Kernel {
template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 900
#if defined __CUDA_ARCH__
#if __CUDA_ARCH__ == 900
Kernel::operator()(std::forward<Args>(args)...);
#else
printf("This kernel only supports sm90.\n");
asm("trap;");
#endif
#endif
}
};
template <typename Kernel>
struct enable_sm100_only : Kernel {
struct enable_sm100f_only : Kernel {
template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000
#if defined __CUDA_ARCH__
#if __CUDA_ARCH__ == 1000 || __CUDA_ARCH__ == 1030
Kernel::operator()(std::forward<Args>(args)...);
#else
printf("This kernel only supports sm100f.\n");
asm("trap;");
#endif
#endif
}
};
template <typename Kernel>
struct enable_sm100a_only : Kernel {
template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__
#if __CUDA_ARCH__ == 1000
Kernel::operator()(std::forward<Args>(args)...);
#else
printf("This kernel only supports sm100a.\n");
asm("trap;");
#endif
#endif
}
};
@@ -65,8 +142,13 @@ template <typename Kernel>
struct enable_sm120_only : Kernel {
template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1200
#if defined __CUDA_ARCH__
#if __CUDA_ARCH__ == 1200
Kernel::operator()(std::forward<Args>(args)...);
#else
printf("This kernel only supports sm120.\n");
asm("trap;");
#endif
#endif
}
};

View File

@@ -141,8 +141,8 @@ struct cutlass_3x_gemm_sm100 {
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
using GemmKernel = enable_sm100f_only<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>>;
};
template <typename ElementAB_, typename ElementD_,
@@ -202,8 +202,8 @@ struct cutlass_3x_gemm_sm120 {
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
using GemmKernel = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>>;
};
} // namespace vllm

View File

@@ -123,7 +123,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
MainloopScheduler
>::CollectiveOp>;
using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal<
using KernelType = enable_sm100f_only<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
struct GemmKernel : public KernelType {};

View File

@@ -90,8 +90,8 @@ struct cutlass_3x_gemm_sm100_fp8 {
// -----------------------------------------------------------
// Kernel definition
// -----------------------------------------------------------
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
using GemmKernel = enable_sm100f_only<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>>;
};
template <typename InType, typename OutType, bool EnableBias>

View File

@@ -36,41 +36,6 @@ using namespace cute;
*/
namespace vllm {
// Wrappers for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm75_to_sm80 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Arch, template <typename> typename ArchGuard,
typename ElementAB_, typename ElementD_,
template <typename, typename> typename Epilogue_, typename TileShape,

View File

@@ -50,7 +50,7 @@ struct sm89_fp8_config_default {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -58,7 +58,7 @@ struct sm89_fp8_config_default {
using TileShape = typename cutlass::gemm::GemmShape<256, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -67,7 +67,7 @@ struct sm89_fp8_config_default {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -100,7 +100,7 @@ struct sm89_fp8_config_M256 {
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -108,7 +108,7 @@ struct sm89_fp8_config_M256 {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -141,7 +141,7 @@ struct sm89_fp8_config_M128 {
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -150,7 +150,7 @@ struct sm89_fp8_config_M128 {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -158,7 +158,7 @@ struct sm89_fp8_config_M128 {
using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -191,7 +191,7 @@ struct sm89_fp8_config_M64 {
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -201,7 +201,7 @@ struct sm89_fp8_config_M64 {
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -211,7 +211,7 @@ struct sm89_fp8_config_M64 {
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -244,7 +244,7 @@ struct sm89_fp8_config_M32 {
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -253,7 +253,7 @@ struct sm89_fp8_config_M32 {
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 4, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -262,7 +262,7 @@ struct sm89_fp8_config_M32 {
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -296,7 +296,7 @@ struct sm89_fp8_config_M16 {
using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, MainLoopStages,
FP8MathOperator>,
@@ -305,7 +305,7 @@ struct sm89_fp8_config_M16 {
using TileShape = typename cutlass::gemm::GemmShape<16, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, MainLoopStages,
FP8MathOperator>,
@@ -314,7 +314,7 @@ struct sm89_fp8_config_M16 {
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, MainLoopStages,
FP8MathOperator>,

View File

@@ -48,7 +48,7 @@ struct sm89_int8_config_default {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -56,7 +56,7 @@ struct sm89_int8_config_default {
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -64,7 +64,7 @@ struct sm89_int8_config_default {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -72,7 +72,7 @@ struct sm89_int8_config_default {
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -104,7 +104,7 @@ struct sm89_int8_config_M256 {
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -112,7 +112,7 @@ struct sm89_int8_config_M256 {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -120,7 +120,7 @@ struct sm89_int8_config_M256 {
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -128,7 +128,7 @@ struct sm89_int8_config_M256 {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -160,7 +160,7 @@ struct sm89_int8_config_M128 {
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -169,7 +169,7 @@ struct sm89_int8_config_M128 {
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -178,7 +178,7 @@ struct sm89_int8_config_M128 {
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -210,7 +210,7 @@ struct sm89_int8_config_M64 {
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -219,7 +219,7 @@ struct sm89_int8_config_M64 {
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -251,7 +251,7 @@ struct sm89_int8_config_M32 {
using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -260,7 +260,7 @@ struct sm89_int8_config_M32 {
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 4>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -292,7 +292,7 @@ struct sm89_int8_config_M16 {
using TileShape = cutlass::gemm::GemmShape<16, 64, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
@@ -300,7 +300,7 @@ struct sm89_int8_config_M16 {
using TileShape = cutlass::gemm::GemmShape<16, 128, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 4>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);