Files
nvfp4-megamoe-kernel/third_party/DeepGEMM/setup.py

215 lines
7.8 KiB
Python

import ast
import os
import re
import shutil
import setuptools
import subprocess
import sys
import torch
import platform
import urllib
import urllib.error
import urllib.request
from setuptools import find_packages
from setuptools.command.build_py import build_py
from packaging.version import parse
from pathlib import Path
from torch.utils.cpp_extension import CUDAExtension, CUDA_HOME
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
from scripts.generate_pyi import generate_pyi_file
def _env_flag(name, default='0'):
    """Return True iff env var ``name`` (falling back to ``default``) parses as the integer 1."""
    return int(os.environ.get(name, default)) == 1


# Build-time switches read from the environment
DG_SKIP_CUDA_BUILD = _env_flag('DG_SKIP_CUDA_BUILD')
DG_FORCE_BUILD = _env_flag('DG_FORCE_BUILD')
DG_USE_LOCAL_VERSION = _env_flag('DG_USE_LOCAL_VERSION', '1')
DG_JIT_USE_RUNTIME_API = _env_flag('DG_JIT_USE_RUNTIME_API')

# Compiler flags; the C++11 ABI define mirrors the setting torch itself was built with
_cxx11_abi = int(torch.compiled_with_cxx11_abi())
cxx_flags = [
    '-std=c++17',
    '-O3',
    '-fPIC',
    '-Wno-psabi',
    '-Wno-deprecated-declarations',
    f'-D_GLIBCXX_USE_CXX11_ABI={_cxx11_abi}',
]
if DG_JIT_USE_RUNTIME_API:
    cxx_flags.append('-DDG_JIT_USE_RUNTIME_API')

# Sources
current_dir = os.path.dirname(os.path.realpath(__file__))
sources = ['csrc/python_api.cpp']
build_include_dirs = [
    f'{CUDA_HOME}/include',
    f'{CUDA_HOME}/include/cccl',
    'deep_gemm/include',
    'third-party/cutlass/include',
    'third-party/fmt/include',
]
build_libraries = ['cudart', 'nvrtc']
build_library_dirs = [f'{CUDA_HOME}/lib64']
# Directories copied into `<build_lib>/deep_gemm/include` by the custom `build_py` command
third_party_include_dirs = [
    'third-party/cutlass/include/cute',
    'third-party/cutlass/include/cutlass',
]

# Release: template for the download URL of prebuilt wheels
base_wheel_url = 'https://github.com/DeepSeek-AI/DeepGEMM/releases/download/{tag_name}/{wheel_name}'
def get_package_version():
    """Return the package version, optionally suffixed with a local revision label.

    The public version is the ``__version__`` literal in ``deep_gemm/__init__.py``.
    When ``DG_USE_LOCAL_VERSION`` is set, a PEP 440 local label is appended:
    ``+<short git hash>`` for a clean checkout, otherwise ``+local`` (dirty
    working tree, git unavailable, or not a git checkout).

    Raises:
        RuntimeError: if no ``__version__`` assignment is found in ``deep_gemm/__init__.py``.
    """
    init_path = Path(current_dir) / 'deep_gemm' / '__init__.py'
    with open(init_path, 'r', encoding='utf-8') as f:
        version_match = re.search(r'^__version__\s*=\s*(.*)$', f.read(), re.MULTILINE)
    if version_match is None:
        raise RuntimeError(f'Cannot find `__version__` in {init_path}')
    public_version = ast.literal_eval(version_match.group(1))

    revision = ''
    if DG_USE_LOCAL_VERSION:
        # noinspection PyBroadException
        try:
            # Run git in the package directory so the result does not depend on the CWD
            status_cmd = ['git', 'status', '--porcelain']
            status_output = subprocess.check_output(status_cmd, cwd=current_dir).decode('ascii').strip()
            if status_output:
                print(f'Warning: Git working directory is not clean. Uncommitted changes:\n{status_output}')
                # A dirty tree must not be labelled with the HEAD hash. This is an explicit
                # branch instead of `assert False`, which `python -O` would strip.
                revision = '+local'
            else:
                cmd = ['git', 'rev-parse', '--short', 'HEAD']
                revision = '+' + subprocess.check_output(cmd, cwd=current_dir).decode('ascii').rstrip()
        except Exception:
            # git missing, not a repository, or undecodable output: use a generic label
            revision = '+local'
    return f'{public_version}{revision}'
def get_platform():
    """Return the wheel platform tag for this host, formatted as ``linux_<machine>``.

    Raises:
        ValueError: if the host is not Linux, the only platform this script supports.
    """
    if not sys.platform.startswith('linux'):
        raise ValueError(f'Unsupported platform: {sys.platform}')
    machine = platform.uname().machine
    return 'linux_' + machine
def get_wheel_url():
    """Build the download URL and file name of the matching prebuilt release wheel.

    The wheel name encodes the package version, the CUDA major version torch
    was built with (not the locally installed toolkit), the torch major.minor
    version, the torch C++11 ABI setting, the CPython tag and the platform tag.

    Returns:
        A ``(wheel_url, wheel_filename)`` tuple.

    Raises:
        RuntimeError: if the installed torch was built without CUDA.
    """
    torch_version = parse(torch.__version__)
    torch_version = f'{torch_version.major}.{torch_version.minor}'
    python_version = f'cp{sys.version_info.major}{sys.version_info.minor}'
    platform_name = get_platform()
    deep_gemm_version = get_package_version()
    # Same public query used for the `-D_GLIBCXX_USE_CXX11_ABI` compile flag, so the
    # downloaded wheel is selected with the ABI a source build would use
    cxx11_abi = int(torch.compiled_with_cxx11_abi())

    # Determine the version numbers that will be used to determine the correct wheel
    # We're using the CUDA version used to build torch, not the one currently installed
    if torch.version.cuda is None:
        # CPU-only torch: `parse(None)` would otherwise fail with an unclear error
        raise RuntimeError('Prebuilt DeepGEMM wheels require a CUDA-enabled PyTorch build')
    cuda_version = parse(torch.version.cuda)
    cuda_version = f'{cuda_version.major}'

    # Determine wheel URL based on CUDA version, torch version, python version and OS
    wheel_filename = f'deep_gemm-{deep_gemm_version}+cu{cuda_version}-torch{torch_version}-cxx11abi{cxx11_abi}-{python_version}-{platform_name}.whl'
    wheel_url = base_wheel_url.format(tag_name=f'v{deep_gemm_version}', wheel_name=wheel_filename)
    return wheel_url, wheel_filename
def get_ext_modules():
    """Return the extension modules to build.

    Yields a single ``deep_gemm._C`` CUDA extension, or an empty list when
    ``DG_SKIP_CUDA_BUILD`` is set so that nothing is compiled.
    """
    if not DG_SKIP_CUDA_BUILD:
        c_extension = CUDAExtension(
            name='deep_gemm._C',
            sources=sources,
            include_dirs=build_include_dirs,
            libraries=build_libraries,
            library_dirs=build_library_dirs,
            extra_compile_args=cxx_flags,
        )
        return [c_extension]
    return []
class CustomBuildPy(build_py):
    """``build_py`` that also stages headers, build-time env defaults and type stubs.

    Before delegating to the stock ``build_py.run``, it copies the directories in
    ``third_party_include_dirs`` into ``<build_lib>/deep_gemm/include``, writes
    ``<build_lib>/deep_gemm/envs.py`` recording selected ``DG_JIT_*`` environment
    variables present at build time, and generates/copies the ``_C.pyi`` stub.
    """

    # Environment variables whose build-time values are baked into `envs.py`
    PERSISTENT_ENV_NAMES = ('DG_JIT_CACHE_DIR', 'DG_JIT_PRINT_COMPILER_COMMAND', 'DG_JIT_CPP_STANDARD')

    def run(self):
        # First, prepare the include directories
        self.prepare_includes()
        # Second, make clusters' cache setting default into `envs.py`
        self.generate_default_envs()
        # Third, generate and copy .pyi file to build root directory
        self.generate_pyi_file()
        # Finally, run the regular build
        build_py.run(self)

    def generate_pyi_file(self):
        """Generate the ``_C`` stub from ``csrc`` and copy it to ``<build_lib>/deep_gemm/_C.pyi``.

        If the generator did not produce ``stubs/_C.pyi``, only a warning is printed.
        """
        # NOTE(review): the generator paths are CWD-relative while `pyi_source` is resolved
        # against `current_dir`; the two agree only when setup.py runs from its own directory.
        generate_pyi_file(name='_C', root='./csrc', output_dir='./stubs')
        pyi_source = os.path.join(current_dir, 'stubs', '_C.pyi')
        pyi_target = os.path.join(self.build_lib, 'deep_gemm', '_C.pyi')
        if os.path.exists(pyi_source):
            print(f"Copying .pyi file from {pyi_source} to {pyi_target}")
            os.makedirs(os.path.dirname(pyi_target), exist_ok=True)
            shutil.copy2(pyi_source, pyi_target)
        else:
            print(f"Warning: .pyi file not found at {pyi_source}")

    def generate_default_envs(self):
        """Write ``<build_lib>/deep_gemm/envs.py`` defining a ``persistent_envs`` dict.

        Each variable in ``PERSISTENT_ENV_NAMES`` that is set in ``os.environ`` becomes
        one entry; unset variables are omitted.
        """
        lines = ['# Pre-installed environment variables', 'persistent_envs = dict()']
        for name in self.PERSISTENT_ENV_NAMES:
            if name in os.environ:
                # `!r` emits a properly quoted/escaped string literal, so values containing
                # quotes or backslashes still produce valid Python source
                lines.append(f'persistent_envs[{name!r}] = {os.environ[name]!r}')
        envs_path = os.path.join(self.build_lib, 'deep_gemm', 'envs.py')
        os.makedirs(os.path.dirname(envs_path), exist_ok=True)
        with open(envs_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(lines) + '\n')

    def prepare_includes(self):
        """Copy each ``third_party_include_dirs`` entry into ``<build_lib>/deep_gemm/include``.

        Copies go into the build tree instead of the package source directory; a copy
        left over from a previous build is replaced.
        """
        build_include_dir = os.path.join(self.build_lib, 'deep_gemm', 'include')
        os.makedirs(build_include_dir, exist_ok=True)
        for rel_dir in third_party_include_dirs:
            src_dir = os.path.join(current_dir, rel_dir)
            # e.g. 'third-party/cutlass/include/cute' -> '<build_include_dir>/cute'
            dst_dir = os.path.join(build_include_dir, rel_dir.split('/')[-1])
            # `copytree` refuses an existing destination, so drop any stale copy first
            if os.path.exists(dst_dir):
                shutil.rmtree(dst_dir)
            shutil.copytree(src_dir, dst_dir)
class CachedWheelsCommand(_bdist_wheel):
    """``bdist_wheel`` that first tries to download a prebuilt release wheel.

    It builds from source directly when ``DG_FORCE_BUILD`` or ``DG_USE_LOCAL_VERSION``
    is set, and falls back to a source build when the download fails.
    """

    def run(self):
        # `http.client.HTTPException` (e.g. `IncompleteRead`) is not an `OSError`
        import http.client

        if DG_FORCE_BUILD or DG_USE_LOCAL_VERSION:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print(f'Try to download wheel from URL: {wheel_url}')
        try:
            # `timeout` bounds each blocking socket operation (connect and each read)
            with urllib.request.urlopen(wheel_url, timeout=1) as response:
                data = response.read()
        except (OSError, http.client.HTTPException):
            # `URLError`/`HTTPError` are `OSError` subclasses, and so are the socket
            # timeouts/resets that can surface from `read()` rather than `urlopen()`
            print('Precompiled wheel not found. Building from source...')
            # If the wheel could not be downloaded, build from source
            return super().run()

        # Make the archive: write the downloaded bytes straight to the name this
        # command would have produced, leaving no stray file in the CWD
        os.makedirs(self.dist_dir, exist_ok=True)
        impl_tag, abi_tag, plat_tag = self.get_tag()
        archive_basename = f'{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}'
        wheel_path = os.path.join(self.dist_dir, archive_basename + '.whl')
        with open(wheel_path, 'wb') as out_file:
            out_file.write(data)
if __name__ == '__main__':
    # Header globs shipped as `deep_gemm` package data; the `cute`/`cutlass` trees are
    # copied into the build tree by `CustomBuildPy.prepare_includes`
    bundled_headers = [
        'include/deep_gemm/**/*',
        'include/cute/**/*',
        'include/cutlass/**/*',
    ]
    # Custom commands: stage extra build artifacts in `build_py`, and let
    # `bdist_wheel` try a prebuilt release wheel before compiling from source
    commands = dict(build_py=CustomBuildPy, bdist_wheel=CachedWheelsCommand)

    # noinspection PyTypeChecker
    setuptools.setup(
        name='deep_gemm',
        version=get_package_version(),
        packages=find_packages('.'),
        package_data=dict(deep_gemm=bundled_headers),
        ext_modules=get_ext_modules(),
        zip_safe=False,
        cmdclass=commands,
    )