2025-07-18 11:32:22 +08:00
|
|
|
import os
|
2025-08-15 18:32:35 +08:00
|
|
|
import subprocess
|
2025-02-25 22:52:41 +08:00
|
|
|
|
2025-07-18 11:32:22 +08:00
|
|
|
# Apply default environment variables provided at setup time (values already set in the environment take precedence)
|
|
|
|
|
# Best-effort: if the `.envs` module cannot be imported, the ImportError is
# ignored. Variables already present in the process environment are never
# overwritten; only missing ones are filled in.
try:
    # noinspection PyUnresolvedReferences
    from .envs import persistent_envs
    for key, value in persistent_envs.items():
        os.environ.setdefault(key, value)
except ImportError:
    pass
|
|
|
|
|
|
|
|
|
|
# Configs
|
2025-08-15 18:32:35 +08:00
|
|
|
import deep_gemm_cpp
|
2025-07-18 11:32:22 +08:00
|
|
|
from deep_gemm_cpp import (
|
|
|
|
|
set_num_sms,
|
2025-08-02 19:52:22 -07:00
|
|
|
get_num_sms,
|
|
|
|
|
set_tc_util,
|
|
|
|
|
get_tc_util,
|
2025-02-25 22:52:41 +08:00
|
|
|
)
|
2025-07-18 11:32:22 +08:00
|
|
|
|
|
|
|
|
# Kernels
|
|
|
|
|
from deep_gemm_cpp import (
|
2025-08-02 19:52:22 -07:00
|
|
|
# FP8 GEMMs
|
2025-07-18 11:32:22 +08:00
|
|
|
fp8_gemm_nt, fp8_gemm_nn,
|
|
|
|
|
fp8_gemm_tn, fp8_gemm_tt,
|
2025-09-25 16:19:07 +08:00
|
|
|
fp8_gemm_nt_skip_head_mid,
|
2025-07-18 11:32:22 +08:00
|
|
|
m_grouped_fp8_gemm_nt_contiguous,
|
|
|
|
|
m_grouped_fp8_gemm_nn_contiguous,
|
2025-08-15 18:32:35 +08:00
|
|
|
m_grouped_fp8_gemm_nt_masked,
|
2025-09-25 16:19:07 +08:00
|
|
|
k_grouped_fp8_gemm_nt_contiguous,
|
2025-08-02 19:52:22 -07:00
|
|
|
k_grouped_fp8_gemm_tn_contiguous,
|
2025-08-15 18:32:35 +08:00
|
|
|
# BF16 GEMMs
|
|
|
|
|
bf16_gemm_nt, bf16_gemm_nn,
|
|
|
|
|
bf16_gemm_tn, bf16_gemm_tt,
|
|
|
|
|
m_grouped_bf16_gemm_nt_contiguous,
|
|
|
|
|
m_grouped_bf16_gemm_nt_masked,
|
2025-09-25 16:19:07 +08:00
|
|
|
# cuBLASLt GEMMs
|
|
|
|
|
cublaslt_gemm_nt, cublaslt_gemm_nn,
|
|
|
|
|
cublaslt_gemm_tn, cublaslt_gemm_tt,
|
|
|
|
|
# Einsum kernels
|
|
|
|
|
einsum,
|
2025-08-02 19:52:22 -07:00
|
|
|
# Layout kernels
|
|
|
|
|
transform_sf_into_required_layout
|
2025-07-18 11:32:22 +08:00
|
|
|
)
|
|
|
|
|
|
2025-08-15 18:32:35 +08:00
|
|
|
# Aliases kept for legacy support: old `<dtype>_m_grouped_*` names map to the
# current `m_grouped_<dtype>_*` functions
# TODO: remove these later
fp8_m_grouped_gemm_nt_masked, bf16_m_grouped_gemm_nt_masked = (
    m_grouped_fp8_gemm_nt_masked,
    m_grouped_bf16_gemm_nt_masked,
)
|
|
|
|
|
|
2025-07-18 11:32:22 +08:00
|
|
|
# Some utils
|
|
|
|
|
from . import testing
|
|
|
|
|
from . import utils
|
|
|
|
|
from .utils import *
|
2025-08-15 18:32:35 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Initialize the C++ extension module
|
|
|
|
|
def _find_cuda_home() -> str:
    """Locate the CUDA toolkit root directory.

    Resolution order:
      1. The ``CUDA_HOME`` environment variable, then ``CUDA_PATH``
         (empty values are treated as unset). These are returned as-is,
         without checking that the directory exists.
      2. The grandparent directory of the ``nvcc`` executable found on
         ``PATH`` (e.g. ``/usr/local/cuda/bin/nvcc`` -> ``/usr/local/cuda``).
      3. ``/usr/local/cuda``, if that directory exists.

    Returns:
        The CUDA home directory path.

    Raises:
        RuntimeError: if none of the above yields a CUDA home.
    """
    # TODO: reuse PyTorch API later
    # For some PyTorch versions, the original `_find_cuda_home` will initialize CUDA,
    # which is incompatible with process forks
    import shutil

    # `or None` so that an empty-string variable falls through to discovery
    # instead of being returned as ''
    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') or None
    if cuda_home is None:
        # `shutil.which` searches PATH in-process: no external `which` binary,
        # no subprocess, and it also works where `which` does not exist
        nvcc = shutil.which('nvcc')
        if nvcc is not None:
            # abspath guards against a relative PATH entry collapsing to ''
            cuda_home = os.path.dirname(os.path.dirname(os.path.abspath(nvcc)))
        elif os.path.exists('/usr/local/cuda'):
            cuda_home = '/usr/local/cuda'

    if cuda_home is None:
        # An explicit raise (unlike `assert`) still fires under `python -O`,
        # so callers never receive None
        raise RuntimeError(
            'Cannot find CUDA home: set CUDA_HOME (or CUDA_PATH), put nvcc on PATH, '
            'or install CUDA under /usr/local/cuda'
        )
    return cuda_home
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Pass the C++ extension its runtime paths: first the library root (the
# directory containing this file), then the CUDA home from `_find_cuda_home`.
deep_gemm_cpp.init(os.path.dirname(os.path.abspath(__file__)), _find_cuda_home())
|