chore: Build and store bdist wheels (#181)
* build: Minor tweaks for wheel build Signed-off-by: oliver könig <okoenig@nvidia.com> * ci: Workflows for wheel build Signed-off-by: oliver könig <okoenig@nvidia.com> * fix Signed-off-by: oliver könig <okoenig@nvidia.com> * fix Signed-off-by: oliver könig <okoenig@nvidia.com> * build: Add CachedWheel Signed-off-by: oliver könig <okoenig@nvidia.com> * add version to init Signed-off-by: oliver könig <okoenig@nvidia.com> * revert Signed-off-by: oliver könig <okoenig@nvidia.com> * revert Signed-off-by: oliver könig <okoenig@nvidia.com> * revert Signed-off-by: oliver könig <okoenig@nvidia.com> * v2 Signed-off-by: oliver könig <okoenig@nvidia.com> * update Signed-off-by: oliver könig <okoenig@nvidia.com> * test Signed-off-by: oliver könig <okoenig@nvidia.com> * from packaging.version import parse Signed-off-by: oliver könig <okoenig@nvidia.com> * local version Signed-off-by: oliver könig <okoenig@nvidia.com> * remove file Signed-off-by: oliver könig <okoenig@nvidia.com> * revert Signed-off-by: oliver könig <okoenig@nvidia.com> * Updates and lint * revert missing cudaextension args Signed-off-by: oliver könig <okoenig@nvidia.com> * Add timeout * fix DG settings Signed-off-by: oliver könig <okoenig@nvidia.com> * DG_USE_LOCAL_VERSION Signed-off-by: oliver könig <okoenig@nvidia.com> * Update version * Detect local changes * Minor fix * Revert CUTLASS * Unify options --------- Signed-off-by: oliver könig <okoenig@nvidia.com> Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
This commit is contained in:
135
setup.py
135
setup.py
@@ -1,15 +1,36 @@
|
||||
import ast
|
||||
import os
|
||||
import setuptools
|
||||
import re
|
||||
import shutil
|
||||
import setuptools
|
||||
import subprocess
|
||||
import sys
|
||||
import torch
|
||||
import platform
|
||||
import urllib
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from setuptools import find_packages
|
||||
from setuptools.command.build_py import build_py
|
||||
from packaging.version import parse
|
||||
from pathlib import Path
|
||||
from torch.utils.cpp_extension import CUDAExtension, CUDA_HOME
|
||||
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
|
||||
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
DG_SKIP_CUDA_BUILD = int(os.getenv('DG_SKIP_CUDA_BUILD', '0')) == 1
|
||||
DG_FORCE_BUILD = int(os.getenv('DG_FORCE_BUILD', '0')) == 1
|
||||
DG_USE_LOCAL_VERSION = int(os.getenv('DG_USE_LOCAL_VERSION', '1')) == 1
|
||||
DG_JIT_USE_RUNTIME_API = int(os.environ.get('DG_JIT_USE_RUNTIME_API', '0')) == 1
|
||||
|
||||
# Compiler flags
|
||||
cxx_flags = ['-std=c++17', '-O3', '-fPIC', '-Wno-psabi', '-Wno-deprecated-declarations',
|
||||
f'-D_GLIBCXX_USE_CXX11_ABI={int(torch.compiled_with_cxx11_abi())}']
|
||||
if DG_JIT_USE_RUNTIME_API:
|
||||
cxx_flags.append('-DDG_JIT_USE_RUNTIME_API')
|
||||
|
||||
# Sources
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
sources = ['csrc/python_api.cpp']
|
||||
build_include_dirs = [
|
||||
f'{CUDA_HOME}/include',
|
||||
@@ -28,9 +49,68 @@ third_party_include_dirs = [
|
||||
'third-party/cutlass/include/cutlass',
|
||||
]
|
||||
|
||||
# Use runtime API
|
||||
if int(os.environ.get('DG_JIT_USE_RUNTIME_API', '0')):
|
||||
cxx_flags.append('-DDG_JIT_USE_RUNTIME_API')
|
||||
# Release
|
||||
base_wheel_url = 'https://github.com/DeepSeek-AI/DeepGEMM/releases/download/{tag_name}/{wheel_name}'
|
||||
|
||||
|
||||
def get_package_version():
    """Return the version string used for this build.

    The public version is read from the ``__version__`` assignment in
    ``deep_gemm/__init__.py`` (located relative to ``current_dir``) and
    evaluated as a Python literal.

    When ``DG_USE_LOCAL_VERSION`` is enabled a local suffix is appended:
    ``+<short git hash>`` if ``git status --porcelain`` reports a clean
    working tree, or ``+local`` if the tree has uncommitted changes or any
    git invocation fails. When it is disabled no suffix is appended.

    Raises:
        RuntimeError: If no ``__version__`` assignment can be found.
    """
    init_path = Path(current_dir) / 'deep_gemm' / '__init__.py'
    with open(init_path, 'r') as f:
        version_match = re.search(r'^__version__\s*=\s*(.*)$', f.read(), re.MULTILINE)
    if version_match is None:
        raise RuntimeError(f'Unable to find __version__ in {init_path}')
    public_version = ast.literal_eval(version_match.group(1))

    revision = ''
    if DG_USE_LOCAL_VERSION:
        # Best effort: any failure while talking to git falls back to '+local'.
        # `Exception` (not a bare except) keeps Ctrl-C and SystemExit working.
        try:
            status_cmd = ['git', 'status', '--porcelain']
            status_output = subprocess.check_output(status_cmd).decode('ascii').strip()
            if status_output:
                print(f'Warning: Git working directory is not clean. Uncommitted changes:\n{status_output}')
                # Explicit branch instead of `assert False`: asserts are
                # stripped under `python -O`, which would silently tag a
                # dirty tree with the commit hash.
                revision = '+local'
            else:
                cmd = ['git', 'rev-parse', '--short', 'HEAD']
                revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
        except Exception:
            revision = '+local'
    return f'{public_version}{revision}'
|
||||
|
||||
|
||||
def get_platform():
    """Return the platform tag ``linux_<machine>`` for the running system.

    Raises:
        ValueError: If ``sys.platform`` does not start with ``'linux'``.
    """
    # Guard clause: only Linux is handled.
    if not sys.platform.startswith('linux'):
        raise ValueError('Unsupported platform: {}'.format(sys.platform))

    machine = platform.uname().machine
    return f'linux_{machine}'
|
||||
|
||||
|
||||
def get_wheel_url():
    """Compose the release URL and file name of the prebuilt wheel.

    The file name encodes the package version, the major CUDA version torch
    was built with, torch's ``major.minor`` version, torch's C++11 ABI flag,
    the CPython tag of the running interpreter and the platform tag.

    Returns:
        A ``(wheel_url, wheel_filename)`` tuple.
    """
    parsed_torch = parse(torch.__version__)
    torch_version = '{}.{}'.format(parsed_torch.major, parsed_torch.minor)

    py_major, py_minor = sys.version_info.major, sys.version_info.minor
    python_version = f'cp{py_major}{py_minor}'

    platform_name = get_platform()
    deep_gemm_version = get_package_version()
    cxx11_abi = int(torch._C._GLIBCXX_USE_CXX11_ABI)

    # The CUDA version torch was built against is used here, not the one
    # installed on the machine; only its major component goes into the name.
    cuda_version = str(parse(torch.version.cuda).major)

    wheel_filename = f'deep_gemm-{deep_gemm_version}+cu{cuda_version}-torch{torch_version}-cxx11abi{cxx11_abi}-{python_version}-{platform_name}.whl'
    release_tag = f'v{deep_gemm_version}'
    return base_wheel_url.format(tag_name=release_tag, wheel_name=wheel_filename), wheel_filename
|
||||
|
||||
|
||||
def get_ext_modules():
    """Return the extension modules to hand to ``setuptools.setup``.

    Returns an empty list when ``DG_SKIP_CUDA_BUILD`` is set; otherwise a
    one-element list holding the ``deep_gemm_cpp`` ``CUDAExtension``.
    """
    if not DG_SKIP_CUDA_BUILD:
        extension = CUDAExtension(
            name='deep_gemm_cpp',
            sources=sources,
            include_dirs=build_include_dirs,
            libraries=build_libraries,
            library_dirs=build_library_dirs,
            extra_compile_args=cxx_flags,
        )
        return [extension]

    # CUDA build explicitly disabled: nothing to compile.
    return []
|
||||
|
||||
|
||||
class CustomBuildPy(build_py):
|
||||
@@ -72,18 +152,37 @@ class CustomBuildPy(build_py):
|
||||
shutil.copytree(src_dir, dst_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
cmd = ['git', 'rev-parse', '--short', 'HEAD']
|
||||
revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
|
||||
except:
|
||||
revision = ''
|
||||
class CachedWheelsCommand(_bdist_wheel):
    """``bdist_wheel`` command that first tries to fetch a prebuilt wheel.

    If ``DG_FORCE_BUILD`` or ``DG_USE_LOCAL_VERSION`` is set, the wheel is
    always built from source. Otherwise the URL from ``get_wheel_url()`` is
    downloaded and stored in ``self.dist_dir`` under the name this command
    would have produced; if the download fails, the normal source build runs.
    """

    def run(self):
        """Download a matching prebuilt wheel, or build one from source."""
        if DG_FORCE_BUILD or DG_USE_LOCAL_VERSION:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print(f'Try to download wheel from URL: {wheel_url}')
        try:
            # Read the whole body before touching the filesystem so that a
            # failed download never leaves an empty or partial file behind.
            with urllib.request.urlopen(wheel_url, timeout=1) as response:
                data = response.read()
        except OSError:
            # HTTPError and URLError are OSError subclasses; catching OSError
            # also covers read timeouts and connection resets raised while
            # reading the body, which are not wrapped in URLError.
            print('Precompiled wheel not found. Building from source...')
            # If the wheel could not be downloaded, build from source
            super().run()
            return

        # Store the download under the file name a local build would use.
        os.makedirs(self.dist_dir, exist_ok=True)
        impl_tag, abi_tag, plat_tag = self.get_tag()
        archive_basename = f'{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}'
        wheel_path = os.path.join(self.dist_dir, archive_basename + '.whl')
        # Writing straight to the destination avoids a rename from the working
        # directory, which fails when the two are on different filesystems.
        with open(wheel_path, 'wb') as out_file:
            out_file.write(data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# noinspection PyTypeChecker
|
||||
setuptools.setup(
|
||||
name='deep_gemm',
|
||||
version='2.1.0' + revision,
|
||||
version=get_package_version(),
|
||||
packages=find_packages('.'),
|
||||
package_data={
|
||||
'deep_gemm': [
|
||||
@@ -92,16 +191,10 @@ if __name__ == '__main__':
|
||||
'include/cutlass/**/*',
|
||||
]
|
||||
},
|
||||
ext_modules=[
|
||||
CUDAExtension(name='deep_gemm_cpp',
|
||||
sources=sources,
|
||||
include_dirs=build_include_dirs,
|
||||
libraries=build_libraries,
|
||||
library_dirs=build_library_dirs,
|
||||
extra_compile_args=cxx_flags)
|
||||
],
|
||||
ext_modules=get_ext_modules(),
|
||||
zip_safe=False,
|
||||
cmdclass={
|
||||
'build_py': CustomBuildPy,
|
||||
'bdist_wheel': CachedWheelsCommand,
|
||||
},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user