[UX] Integrate DeepGEMM into vLLM wheel via CMake (#37980)
Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -12,6 +12,9 @@ vllm/third_party/triton_kernels/*
|
||||
# FlashMLA interface copied from source
|
||||
vllm/third_party/flashmla/flash_mla_interface.py
|
||||
|
||||
# DeepGEMM vendored package built from source
|
||||
vllm/third_party/deep_gemm/
|
||||
|
||||
# triton jit
|
||||
.triton
|
||||
|
||||
|
||||
@@ -1222,6 +1222,7 @@ endif()
|
||||
|
||||
# For CUDA we also build and ship some external projects.
|
||||
if (VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
include(cmake/external_projects/deepgemm.cmake)
|
||||
include(cmake/external_projects/flashmla.cmake)
|
||||
include(cmake/external_projects/qutlass.cmake)
|
||||
|
||||
|
||||
151
cmake/external_projects/deepgemm.cmake
Normal file
151
cmake/external_projects/deepgemm.cmake
Normal file
@@ -0,0 +1,151 @@
|
||||
include(FetchContent)
|
||||
|
||||
# If DEEPGEMM_SRC_DIR is set, DeepGEMM is built from that directory
|
||||
# instead of downloading.
|
||||
# It can be set as an environment variable or passed as a cmake argument.
|
||||
# The environment variable takes precedence.
|
||||
if (DEFINED ENV{DEEPGEMM_SRC_DIR})
|
||||
set(DEEPGEMM_SRC_DIR $ENV{DEEPGEMM_SRC_DIR})
|
||||
endif()
|
||||
|
||||
if(DEEPGEMM_SRC_DIR)
|
||||
FetchContent_Declare(
|
||||
deepgemm
|
||||
SOURCE_DIR ${DEEPGEMM_SRC_DIR}
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
)
|
||||
else()
|
||||
# This ref should be kept in sync with tools/install_deepgemm.sh
|
||||
FetchContent_Declare(
|
||||
deepgemm
|
||||
GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM.git
|
||||
GIT_TAG 477618cd51baffca09c4b0b87e97c03fe827ef03
|
||||
GIT_SUBMODULES "third-party/cutlass" "third-party/fmt"
|
||||
GIT_PROGRESS TRUE
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
)
|
||||
endif()
|
||||
|
||||
# Use FetchContent_Populate (not MakeAvailable) to avoid processing
|
||||
# DeepGEMM's own CMakeLists.txt which has incompatible find_package calls.
|
||||
FetchContent_GetProperties(deepgemm)
|
||||
if(NOT deepgemm_POPULATED)
|
||||
FetchContent_Populate(deepgemm)
|
||||
endif()
|
||||
message(STATUS "DeepGEMM is available at ${deepgemm_SOURCE_DIR}")
|
||||
|
||||
# DeepGEMM requires CUDA 12.3+ for SM90, 12.9+ for SM100
|
||||
set(DEEPGEMM_SUPPORT_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3)
|
||||
list(APPEND DEEPGEMM_SUPPORT_ARCHS "9.0a")
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.9)
|
||||
list(APPEND DEEPGEMM_SUPPORT_ARCHS "10.0f")
|
||||
elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
list(APPEND DEEPGEMM_SUPPORT_ARCHS "10.0a")
|
||||
endif()
|
||||
|
||||
cuda_archs_loose_intersection(DEEPGEMM_ARCHS
|
||||
"${DEEPGEMM_SUPPORT_ARCHS}" "${CUDA_ARCHS}")
|
||||
|
||||
if(DEEPGEMM_ARCHS)
|
||||
message(STATUS "DeepGEMM CUDA architectures: ${DEEPGEMM_ARCHS}")
|
||||
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
|
||||
#
|
||||
# Build the _C pybind11 extension from DeepGEMM's C++ source.
|
||||
# This is a CXX-only module — CUDA kernels are JIT-compiled at runtime.
|
||||
#
|
||||
Python_add_library(_deep_gemm_C MODULE WITH_SOABI
|
||||
"${deepgemm_SOURCE_DIR}/csrc/python_api.cpp")
|
||||
|
||||
# The pybind11 module name must be _C to match DeepGEMM's Python imports.
|
||||
set_target_properties(_deep_gemm_C PROPERTIES OUTPUT_NAME "_C")
|
||||
|
||||
target_compile_definitions(_deep_gemm_C PRIVATE
|
||||
"-DTORCH_EXTENSION_NAME=_C")
|
||||
|
||||
target_include_directories(_deep_gemm_C PRIVATE
|
||||
"${deepgemm_SOURCE_DIR}/csrc"
|
||||
"${deepgemm_SOURCE_DIR}/deep_gemm/include"
|
||||
"${deepgemm_SOURCE_DIR}/third-party/cutlass/include"
|
||||
"${deepgemm_SOURCE_DIR}/third-party/cutlass/tools/util/include"
|
||||
"${deepgemm_SOURCE_DIR}/third-party/fmt/include")
|
||||
|
||||
target_compile_options(_deep_gemm_C PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-std=c++17>
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-O3>
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-Wno-psabi>
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-Wno-deprecated-declarations>)
|
||||
|
||||
# torch_python is required because DeepGEMM uses pybind11 type casters
|
||||
# for at::Tensor (via PYBIND11_MODULE), unlike vLLM's own extensions which
|
||||
# use torch::Library custom ops.
|
||||
find_library(TORCH_PYTHON_LIBRARY torch_python
|
||||
PATHS "${TORCH_INSTALL_PREFIX}/lib"
|
||||
REQUIRED)
|
||||
|
||||
target_link_libraries(_deep_gemm_C PRIVATE
|
||||
torch ${TORCH_LIBRARIES} "${TORCH_PYTHON_LIBRARY}"
|
||||
CUDA::cudart CUDA::nvrtc)
|
||||
|
||||
# Install the shared library into the vendored package directory
|
||||
install(TARGETS _deep_gemm_C
|
||||
LIBRARY DESTINATION vllm/third_party/deep_gemm
|
||||
COMPONENT _deep_gemm_C)
|
||||
|
||||
#
|
||||
# Vendor DeepGEMM Python package files
|
||||
#
|
||||
install(FILES
|
||||
"${deepgemm_SOURCE_DIR}/deep_gemm/__init__.py"
|
||||
DESTINATION vllm/third_party/deep_gemm
|
||||
COMPONENT _deep_gemm_C)
|
||||
|
||||
install(DIRECTORY "${deepgemm_SOURCE_DIR}/deep_gemm/utils/"
|
||||
DESTINATION vllm/third_party/deep_gemm/utils
|
||||
COMPONENT _deep_gemm_C
|
||||
FILES_MATCHING PATTERN "*.py")
|
||||
|
||||
install(DIRECTORY "${deepgemm_SOURCE_DIR}/deep_gemm/testing/"
|
||||
DESTINATION vllm/third_party/deep_gemm/testing
|
||||
COMPONENT _deep_gemm_C
|
||||
FILES_MATCHING PATTERN "*.py")
|
||||
|
||||
install(DIRECTORY "${deepgemm_SOURCE_DIR}/deep_gemm/legacy/"
|
||||
DESTINATION vllm/third_party/deep_gemm/legacy
|
||||
COMPONENT _deep_gemm_C
|
||||
FILES_MATCHING PATTERN "*.py")
|
||||
|
||||
# Generate envs.py (normally generated by DeepGEMM's setup.py build step)
|
||||
file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/deep_gemm_envs.py"
|
||||
"# Pre-installed environment variables\npersistent_envs = dict()\n")
|
||||
install(FILES "${CMAKE_CURRENT_BINARY_DIR}/deep_gemm_envs.py"
|
||||
DESTINATION vllm/third_party/deep_gemm
|
||||
RENAME envs.py
|
||||
COMPONENT _deep_gemm_C)
|
||||
|
||||
#
|
||||
# Install include files needed for JIT compilation at runtime.
|
||||
# The JIT compiler finds these relative to the package directory.
|
||||
#
|
||||
|
||||
# DeepGEMM's own CUDA headers
|
||||
install(DIRECTORY "${deepgemm_SOURCE_DIR}/deep_gemm/include/"
|
||||
DESTINATION vllm/third_party/deep_gemm/include
|
||||
COMPONENT _deep_gemm_C)
|
||||
|
||||
# CUTLASS and CuTe headers (vendored for JIT, separate from vLLM's CUTLASS)
|
||||
install(DIRECTORY "${deepgemm_SOURCE_DIR}/third-party/cutlass/include/"
|
||||
DESTINATION vllm/third_party/deep_gemm/include
|
||||
COMPONENT _deep_gemm_C)
|
||||
|
||||
else()
|
||||
message(STATUS "DeepGEMM will not compile: "
|
||||
"unsupported CUDA architecture ${CUDA_ARCHS}")
|
||||
# Create empty target so setup.py doesn't fail on unsupported systems
|
||||
add_custom_target(_deep_gemm_C)
|
||||
endif()
|
||||
@@ -315,7 +315,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
#################### CSRC BUILD IMAGE ####################
|
||||
|
||||
#################### EXTENSIONS BUILD IMAGE ####################
|
||||
# Build DeepGEMM, DeepEP - runs in PARALLEL with csrc-build
|
||||
# Build DeepEP - runs in PARALLEL with csrc-build
|
||||
# This stage is independent and doesn't affect csrc cache
|
||||
FROM base AS extensions-build
|
||||
ARG CUDA_VERSION
|
||||
@@ -327,21 +327,6 @@ ENV UV_LINK_MODE=copy
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
# Build DeepGEMM wheel
|
||||
# Default moved here from tools/install_deepgemm.sh for centralized version management
|
||||
ARG DEEPGEMM_GIT_REF=477618cd51baffca09c4b0b87e97c03fe827ef03
|
||||
COPY tools/install_deepgemm.sh /tmp/install_deepgemm.sh
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
mkdir -p /tmp/deepgemm/dist && \
|
||||
VLLM_DOCKER_BUILD_CONTEXT=1 TORCH_CUDA_ARCH_LIST="9.0a 10.0a" /tmp/install_deepgemm.sh \
|
||||
--cuda-version "${CUDA_VERSION}" \
|
||||
${DEEPGEMM_GIT_REF:+--ref "$DEEPGEMM_GIT_REF"} \
|
||||
--wheel-dir /tmp/deepgemm/dist || \
|
||||
echo "DeepGEMM build skipped (CUDA version requirement not met)"
|
||||
|
||||
# Ensure the wheel dir exists so COPY won't fail when DeepGEMM is skipped
|
||||
RUN mkdir -p /tmp/deepgemm/dist && touch /tmp/deepgemm/dist/.deepgemm_skipped
|
||||
|
||||
# Build DeepEP wheels
|
||||
COPY tools/ep_kernels/install_python_libraries.sh /tmp/install_python_libraries.sh
|
||||
# Defaults moved here from tools/ep_kernels/install_python_libraries.sh for centralized version management
|
||||
@@ -426,7 +411,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38
|
||||
|
||||
# Copy extension wheels from extensions-build stage for later use
|
||||
COPY --from=extensions-build /tmp/deepgemm/dist /tmp/deepgemm/dist
|
||||
COPY --from=extensions-build /tmp/ep_kernels_workspace/dist /tmp/ep_kernels_workspace/dist
|
||||
|
||||
# Check the size of the wheel if RUN_WHEEL_CHECK is true
|
||||
@@ -693,15 +677,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
. /etc/environment && \
|
||||
uv pip list
|
||||
|
||||
# Install deepgemm wheel that has been built in the `build` stage
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,from=build,source=/tmp/deepgemm/dist,target=/tmp/deepgemm/dist,ro \
|
||||
sh -c 'if ls /tmp/deepgemm/dist/*.whl >/dev/null 2>&1; then \
|
||||
uv pip install --system /tmp/deepgemm/dist/*.whl; \
|
||||
else \
|
||||
echo "No DeepGEMM wheels to install; skipping."; \
|
||||
fi'
|
||||
|
||||
# Pytorch now installs NVSHMEM, setting LD_LIBRARY_PATH
|
||||
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
|
||||
|
||||
|
||||
@@ -52,9 +52,6 @@
|
||||
"vllm_target_device": {
|
||||
"default": "cuda"
|
||||
},
|
||||
"DEEPGEMM_GIT_REF": {
|
||||
"default": "477618cd51baffca09c4b0b87e97c03fe827ef03"
|
||||
},
|
||||
"DEEPEP_COMMIT_HASH": {
|
||||
"default": "73b6ea4"
|
||||
},
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 325 KiB After Width: | Height: | Size: 315 KiB |
29
setup.py
29
setup.py
@@ -379,6 +379,20 @@ class cmake_build_ext(build_ext):
|
||||
dirs_exist_ok=True,
|
||||
)
|
||||
|
||||
if _is_cuda():
|
||||
# copy vendored deep_gemm package from build_lib to source tree
|
||||
# for editable installs
|
||||
deep_gemm_build = os.path.join(
|
||||
self.build_lib, "vllm", "third_party", "deep_gemm"
|
||||
)
|
||||
if os.path.exists(deep_gemm_build):
|
||||
print(f"Copying {deep_gemm_build} to vllm/third_party/deep_gemm")
|
||||
shutil.copytree(
|
||||
deep_gemm_build,
|
||||
"vllm/third_party/deep_gemm",
|
||||
dirs_exist_ok=True,
|
||||
)
|
||||
|
||||
|
||||
class precompiled_build_ext(build_ext):
|
||||
"""Disables extension building when using precompiled binaries."""
|
||||
@@ -685,6 +699,8 @@ class precompiled_wheel_utils:
|
||||
flashmla_regex = re.compile(
|
||||
r"vllm/third_party/flashmla/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
|
||||
)
|
||||
# DeepGEMM: extract all files (.py, .so, .cuh, .h, .hpp, etc.)
|
||||
deep_gemm_regex = re.compile(r"vllm/third_party/deep_gemm/.*")
|
||||
file_members = list(
|
||||
filter(lambda x: x.filename in files_to_copy, wheel.filelist)
|
||||
)
|
||||
@@ -699,6 +715,9 @@ class precompiled_wheel_utils:
|
||||
file_members += list(
|
||||
filter(lambda x: flashmla_regex.match(x.filename), wheel.filelist)
|
||||
)
|
||||
file_members += list(
|
||||
filter(lambda x: deep_gemm_regex.match(x.filename), wheel.filelist)
|
||||
)
|
||||
|
||||
for file in file_members:
|
||||
print(f"[extract] {file.filename}")
|
||||
@@ -987,6 +1006,12 @@ if _is_cuda():
|
||||
ext_modules.append(
|
||||
CMakeExtension(name="vllm._flashmla_extension_C", optional=True)
|
||||
)
|
||||
if envs.VLLM_USE_PRECOMPILED or (
|
||||
CUDA_HOME and get_nvcc_cuda_version() >= Version("12.3")
|
||||
):
|
||||
# DeepGEMM requires CUDA 12.3+ (SM90/SM100)
|
||||
# Optional since it won't build on unsupported architectures
|
||||
ext_modules.append(CMakeExtension(name="vllm._deep_gemm_C", optional=True))
|
||||
|
||||
if _is_cpu():
|
||||
import platform
|
||||
@@ -1014,6 +1039,10 @@ package_data = {
|
||||
"entrypoints/serve/instrumentator/static/*.js",
|
||||
"entrypoints/serve/instrumentator/static/*.css",
|
||||
"distributed/kv_transfer/kv_connector/v1/hf3fs/utils/*.cpp",
|
||||
# DeepGEMM JIT include headers (vendored via cmake)
|
||||
"third_party/deep_gemm/include/**/*.cuh",
|
||||
"third_party/deep_gemm/include/**/*.h",
|
||||
"third_party/deep_gemm/include/**/*.hpp",
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
get_fp8_min_max,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import DeepGemmQuantScaleFMT, has_deep_gemm
|
||||
from vllm.utils.deep_gemm import (
|
||||
DeepGemmQuantScaleFMT,
|
||||
has_deep_gemm,
|
||||
transform_sf_into_required_layout,
|
||||
)
|
||||
from vllm.utils.math_utils import cdiv, round_up
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
@@ -256,8 +260,6 @@ def test_silu_mul_fp8_quant_deep_gemm(E: int, T: int, H: int, fp8_type: torch.dt
|
||||
and current_platform.has_device_capability(100)
|
||||
and scale_fmt == DeepGemmQuantScaleFMT.UE8M0
|
||||
):
|
||||
from deep_gemm import transform_sf_into_required_layout
|
||||
|
||||
_q, _s = ref_with_scale_fmt(
|
||||
E,
|
||||
T,
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
set -e
|
||||
|
||||
# Default values
|
||||
# Keep DEEPGEMM_GIT_REF in sync with cmake/external_projects/deepgemm.cmake
|
||||
DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git"
|
||||
DEEPGEMM_GIT_REF="477618cd51baffca09c4b0b87e97c03fe827ef03"
|
||||
WHEEL_DIR=""
|
||||
|
||||
@@ -138,6 +138,41 @@ _get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None
|
||||
_transform_sf_into_required_layout_impl: Callable[..., Any] | None = None
|
||||
|
||||
|
||||
def _import_deep_gemm():
|
||||
"""Import the deep_gemm module.
|
||||
|
||||
Prefers an externally installed ``deep_gemm`` package (so users can
|
||||
pin a specific version), then falls back to the vendored copy bundled
|
||||
in the vLLM wheel.
|
||||
|
||||
Returns ``None`` when neither source is usable.
|
||||
"""
|
||||
# 1. Try the external (pip-installed) package first.
|
||||
try:
|
||||
module = importlib.import_module("deep_gemm")
|
||||
logger.debug_once("Imported deep_gemm module from site-packages")
|
||||
return module
|
||||
except ImportError:
|
||||
logger.debug_once(
|
||||
"deep_gemm not found in site-packages, "
|
||||
"trying vendored vllm.third_party.deep_gemm"
|
||||
)
|
||||
|
||||
# 2. Fall back to the vendored copy bundled in the vLLM wheel.
|
||||
try:
|
||||
module = importlib.import_module("vllm.third_party.deep_gemm")
|
||||
logger.debug_once("Imported deep_gemm module from vllm.third_party.deep_gemm")
|
||||
return module
|
||||
except ImportError:
|
||||
logger.debug_once("Vendored deep_gemm not found either")
|
||||
except Exception as e:
|
||||
# The vendored module may raise RuntimeError during _C.init()
|
||||
# if JIT include files are missing (e.g. incomplete wheel).
|
||||
logger.warning_once("Failed to import vendored deep_gemm: %s", e)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _lazy_init() -> None:
|
||||
"""Import deep_gemm and resolve symbols on first use."""
|
||||
global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
|
||||
@@ -169,7 +204,9 @@ def _lazy_init() -> None:
|
||||
envs.VLLM_CACHE_ROOT, "deep_gemm"
|
||||
)
|
||||
|
||||
_dg = importlib.import_module("deep_gemm")
|
||||
_dg = _import_deep_gemm()
|
||||
if _dg is None:
|
||||
return
|
||||
|
||||
_fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
|
||||
_grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
|
||||
@@ -193,8 +230,18 @@ def _lazy_init() -> None:
|
||||
|
||||
def get_num_sms() -> int:
|
||||
_lazy_init()
|
||||
_dg = importlib.import_module("deep_gemm")
|
||||
return int(_dg.get_num_sms())
|
||||
dg = _import_deep_gemm()
|
||||
if dg is None:
|
||||
raise RuntimeError("DeepGEMM is not available")
|
||||
return int(dg.get_num_sms())
|
||||
|
||||
|
||||
def set_num_sms(num_sms: int) -> None:
|
||||
_lazy_init()
|
||||
dg = _import_deep_gemm()
|
||||
if dg is None:
|
||||
raise RuntimeError("DeepGEMM is not available")
|
||||
dg.set_num_sms(num_sms)
|
||||
|
||||
|
||||
@functools.cache
|
||||
@@ -446,6 +493,7 @@ __all__ = [
|
||||
"is_deep_gemm_e8m0_used",
|
||||
"is_deep_gemm_supported",
|
||||
"get_num_sms",
|
||||
"set_num_sms",
|
||||
"should_use_deepgemm_for_fp8_linear",
|
||||
"get_col_major_tma_aligned_tensor",
|
||||
"get_mk_alignment_for_contiguous_layout",
|
||||
|
||||
@@ -408,8 +408,13 @@ def has_deep_ep() -> bool:
|
||||
|
||||
|
||||
def has_deep_gemm() -> bool:
|
||||
"""Whether the optional `deep_gemm` package is available."""
|
||||
return _has_module("deep_gemm")
|
||||
"""Whether the optional `deep_gemm` package is available.
|
||||
|
||||
Prefers an externally installed ``deep_gemm`` package (so users can
|
||||
override with a newer version), then falls back to the vendored copy
|
||||
bundled in the vLLM wheel.
|
||||
"""
|
||||
return _has_module("deep_gemm") or _has_module("vllm.third_party.deep_gemm")
|
||||
|
||||
|
||||
def has_nixl_ep() -> bool:
|
||||
|
||||
@@ -23,6 +23,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.offloader.base import get_offloader
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.deep_gemm import set_num_sms as deep_gemm_set_num_sms
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
|
||||
@@ -158,9 +159,7 @@ class UBatchWrapper:
|
||||
# TODO(lucas): support other kernels besides DeepGEMM
|
||||
set_compute_sms = lambda sms: None
|
||||
if has_deep_gemm() and comm_sms > 0:
|
||||
import deep_gemm as dg
|
||||
|
||||
set_compute_sms = lambda sms: dg.set_num_sms(sms)
|
||||
set_compute_sms = lambda sms: deep_gemm_set_num_sms(sms)
|
||||
|
||||
return SMControlContextManager(
|
||||
comm_sms=comm_sms,
|
||||
|
||||
Reference in New Issue
Block a user