diff --git a/setup.py b/setup.py index 916f760..22a65ba 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ import shutil import subprocess from setuptools import find_packages from setuptools.command.build_py import build_py -from torch.utils.cpp_extension import CppExtension, CUDA_HOME +from torch.utils.cpp_extension import CUDAExtension, CUDA_HOME current_dir = os.path.dirname(os.path.realpath(__file__)) cxx_flags = ['-std=c++20', '-O3', '-fPIC', '-Wno-psabi'] @@ -89,7 +89,7 @@ if __name__ == '__main__': ] }, ext_modules=[ - CppExtension(name='deep_gemm_cpp', + CUDAExtension(name='deep_gemm_cpp', sources=sources, include_dirs=build_include_dirs, libraries=build_libraries,