Files
DeepGEMM/deep_gemm/__init__.py
Ray Wang 9da4a23561 Add more GPU architectures support (#112)
* Add more GPU architectures support

* Update layout.py

* Optimize performance, Add SM90 support, Add 1D2D SM100 support

* Add fmtlib submodule at commit 553ec11

---------

Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
2025-07-18 11:32:22 +08:00

42 lines
957 B
Python

import os
import torch
import torch.utils.cpp_extension
# Apply the default environment variables provided at setup. A variable that
# is already present in the process environment is left untouched.
try:
    # noinspection PyUnresolvedReferences
    from .envs import persistent_envs
except ImportError:
    # The `envs` module could not be imported, so there is nothing to apply
    pass
else:
    for key, value in persistent_envs.items():
        # `setdefault` only writes when `key` is not set yet
        os.environ.setdefault(key, value)
# Load the CPP module and initialize it before anything is imported from it
import deep_gemm_cpp

# Library root directory path: the directory that contains this file
_library_root = os.path.dirname(os.path.abspath(__file__))
# CUDA home, as reported by PyTorch's C++ extension utilities
_cuda_home = torch.utils.cpp_extension.CUDA_HOME
deep_gemm_cpp.init(_library_root, _cuda_home)
# Configs
from deep_gemm_cpp import (
set_num_sms,
get_num_sms
)
# Kernels
from deep_gemm_cpp import (
fp8_gemm_nt, fp8_gemm_nn,
fp8_gemm_tn, fp8_gemm_tt,
m_grouped_fp8_gemm_nt_contiguous,
m_grouped_fp8_gemm_nn_contiguous,
fp8_m_grouped_gemm_nt_masked,
k_grouped_fp8_gemm_tn_contiguous
)
# Some utils
from . import testing
from . import utils
from .utils import *