Embed DeepGEMM source (not submodule) for SM100 raw CUDA GEMM primitives
This commit is contained in:
28
third_party/DeepGEMM/csrc/python_api.cpp
vendored
Normal file
28
third_party/DeepGEMM/csrc/python_api.cpp
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "apis/attention.hpp"
|
||||
#include "apis/einsum.hpp"
|
||||
#include "apis/hyperconnection.hpp"
|
||||
#include "apis/gemm.hpp"
|
||||
#include "apis/layout.hpp"
|
||||
#include "apis/mega.hpp"
|
||||
#include "apis/runtime.hpp"
|
||||
|
||||
#ifndef TORCH_EXTENSION_NAME
|
||||
#define TORCH_EXTENSION_NAME _C
|
||||
#endif
|
||||
|
||||
// ReSharper disable once CppParameterMayBeConstPtrOrRef
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.doc() = "DeepGEMM C++ library";
|
||||
|
||||
// TODO: make SM80 incompatible issues raise errors
|
||||
deep_gemm::attention::register_apis(m);
|
||||
deep_gemm::einsum::register_apis(m);
|
||||
deep_gemm::hyperconnection::register_apis(m);
|
||||
deep_gemm::gemm::register_apis(m);
|
||||
deep_gemm::layout::register_apis(m);
|
||||
deep_gemm::mega::register_apis(m);
|
||||
deep_gemm::runtime::register_apis(m);
|
||||
}
|
||||
Reference in New Issue
Block a user