[UX] Integrate DeepGEMM into vLLM wheel via CMake (#37980)
Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -12,6 +12,9 @@ vllm/third_party/triton_kernels/*
|
|||||||
# FlashMLA interface copied from source
|
# FlashMLA interface copied from source
|
||||||
vllm/third_party/flashmla/flash_mla_interface.py
|
vllm/third_party/flashmla/flash_mla_interface.py
|
||||||
|
|
||||||
|
# DeepGEMM vendored package built from source
|
||||||
|
vllm/third_party/deep_gemm/
|
||||||
|
|
||||||
# triton jit
|
# triton jit
|
||||||
.triton
|
.triton
|
||||||
|
|
||||||
|
|||||||
@@ -1222,6 +1222,7 @@ endif()
|
|||||||
|
|
||||||
# For CUDA we also build and ship some external projects.
|
# For CUDA we also build and ship some external projects.
|
||||||
if (VLLM_GPU_LANG STREQUAL "CUDA")
|
if (VLLM_GPU_LANG STREQUAL "CUDA")
|
||||||
|
include(cmake/external_projects/deepgemm.cmake)
|
||||||
include(cmake/external_projects/flashmla.cmake)
|
include(cmake/external_projects/flashmla.cmake)
|
||||||
include(cmake/external_projects/qutlass.cmake)
|
include(cmake/external_projects/qutlass.cmake)
|
||||||
|
|
||||||
|
|||||||
151
cmake/external_projects/deepgemm.cmake
Normal file
151
cmake/external_projects/deepgemm.cmake
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
include(FetchContent)

# Resolve where the DeepGEMM sources come from.
#
# DEEPGEMM_SRC_DIR selects a local checkout to build from instead of
# downloading. It may be passed as a cmake argument or exported as an
# environment variable; when the environment variable is defined it
# overrides the cmake argument.
if (DEFINED ENV{DEEPGEMM_SRC_DIR})
  set(DEEPGEMM_SRC_DIR $ENV{DEEPGEMM_SRC_DIR})
endif()

# Collect the source-specific arguments first so that the declaration
# itself is written only once.
if(NOT DEEPGEMM_SRC_DIR)
  # Pinned upstream revision; only the two submodules used by this build
  # are requested.
  # This ref should be kept in sync with tools/install_deepgemm.sh
  set(DEEPGEMM_FETCH_ARGS
    GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM.git
    GIT_TAG 477618cd51baffca09c4b0b87e97c03fe827ef03
    GIT_SUBMODULES "third-party/cutlass" "third-party/fmt"
    GIT_PROGRESS TRUE)
else()
  set(DEEPGEMM_FETCH_ARGS SOURCE_DIR ${DEEPGEMM_SRC_DIR})
endif()

FetchContent_Declare(
  deepgemm
  ${DEEPGEMM_FETCH_ARGS}
  CONFIGURE_COMMAND ""
  BUILD_COMMAND ""
)

# Populate only (no FetchContent_MakeAvailable): DeepGEMM's own
# CMakeLists.txt must not be processed because its find_package calls are
# incompatible with this build.
FetchContent_GetProperties(deepgemm)
if(NOT deepgemm_POPULATED)
  FetchContent_Populate(deepgemm)
endif()
message(STATUS "DeepGEMM is available at ${deepgemm_SOURCE_DIR}")
|
||||||
|
|
||||||
|
# Choose the architectures DeepGEMM may be built for, based on the CUDA
# compiler version:
#   * SM90  -> "9.0a"  with CUDA 12.3 or newer
#   * SM100 -> "10.0f" with CUDA 12.9 or newer,
#              "10.0a" with CUDA 12.8 (only when 12.9+ is not available)
# With an older compiler the list stays empty.
#
# The version variable is quoted so that an unset/empty value evaluates to
# "not supported" instead of producing a malformed if() expression.
set(DEEPGEMM_SUPPORT_ARCHS)
if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_GREATER_EQUAL 12.3)
  list(APPEND DEEPGEMM_SUPPORT_ARCHS "9.0a")
endif()
if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_GREATER_EQUAL 12.9)
  list(APPEND DEEPGEMM_SUPPORT_ARCHS "10.0f")
elseif("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_GREATER_EQUAL 12.8)
  list(APPEND DEEPGEMM_SUPPORT_ARCHS "10.0a")
endif()

# Reduce the supported list to what was actually requested via CUDA_ARCHS.
# NOTE(review): cuda_archs_loose_intersection is defined outside this file;
# its exact matching rules should be confirmed there.
cuda_archs_loose_intersection(DEEPGEMM_ARCHS
  "${DEEPGEMM_SUPPORT_ARCHS}" "${CUDA_ARCHS}")
|
||||||
|
|
||||||
|
if(NOT DEEPGEMM_ARCHS)
  # Nothing in the requested architecture list is usable for DeepGEMM.
  message(STATUS "DeepGEMM will not compile: "
                 "unsupported CUDA architecture ${CUDA_ARCHS}")
  # Create empty target so setup.py doesn't fail on unsupported systems
  add_custom_target(_deep_gemm_C)
else()
  message(STATUS "DeepGEMM CUDA architectures: ${DEEPGEMM_ARCHS}")

  find_package(CUDAToolkit REQUIRED)

  # Locations reused throughout this section.
  set(DEEPGEMM_PY_SRC "${deepgemm_SOURCE_DIR}/deep_gemm")
  set(DEEPGEMM_CUTLASS_INCLUDE
      "${deepgemm_SOURCE_DIR}/third-party/cutlass/include")
  set(DEEPGEMM_INSTALL_DIR vllm/third_party/deep_gemm)

  #
  # The _C pybind11 extension, compiled from DeepGEMM's C++ source.
  # Only C++ is compiled here; the CUDA kernels are JIT-compiled at runtime.
  #
  Python_add_library(_deep_gemm_C MODULE WITH_SOABI
    "${deepgemm_SOURCE_DIR}/csrc/python_api.cpp")

  # DeepGEMM's Python code imports the module as `_C`, so both the output
  # file name and the pybind11 module name have to be `_C`.
  set_target_properties(_deep_gemm_C PROPERTIES OUTPUT_NAME "_C")
  target_compile_definitions(_deep_gemm_C PRIVATE
    "-DTORCH_EXTENSION_NAME=_C")

  target_include_directories(_deep_gemm_C PRIVATE
    "${deepgemm_SOURCE_DIR}/csrc"
    "${DEEPGEMM_PY_SRC}/include"
    "${DEEPGEMM_CUTLASS_INCLUDE}"
    "${deepgemm_SOURCE_DIR}/third-party/cutlass/tools/util/include"
    "${deepgemm_SOURCE_DIR}/third-party/fmt/include")

  target_compile_options(_deep_gemm_C PRIVATE
    $<$<COMPILE_LANGUAGE:CXX>:-std=c++17>
    $<$<COMPILE_LANGUAGE:CXX>:-O3>
    $<$<COMPILE_LANGUAGE:CXX>:-Wno-psabi>
    $<$<COMPILE_LANGUAGE:CXX>:-Wno-deprecated-declarations>)

  # DeepGEMM exposes at::Tensor through pybind11 type casters
  # (PYBIND11_MODULE), which needs torch_python; vLLM's own extensions use
  # torch::Library custom ops and do not.
  find_library(TORCH_PYTHON_LIBRARY torch_python
    PATHS "${TORCH_INSTALL_PREFIX}/lib"
    REQUIRED)

  target_link_libraries(_deep_gemm_C PRIVATE
    torch ${TORCH_LIBRARIES} "${TORCH_PYTHON_LIBRARY}"
    CUDA::cudart CUDA::nvrtc)

  # The compiled module goes into the vendored package directory.
  install(TARGETS _deep_gemm_C
    LIBRARY DESTINATION ${DEEPGEMM_INSTALL_DIR}
    COMPONENT _deep_gemm_C)

  #
  # Vendored DeepGEMM Python package: the top-level __init__.py plus the
  # .py files of each listed sub-package.
  #
  install(FILES "${DEEPGEMM_PY_SRC}/__init__.py"
    DESTINATION ${DEEPGEMM_INSTALL_DIR}
    COMPONENT _deep_gemm_C)

  foreach(DEEPGEMM_SUBPKG IN ITEMS utils testing legacy)
    install(DIRECTORY "${DEEPGEMM_PY_SRC}/${DEEPGEMM_SUBPKG}/"
      DESTINATION ${DEEPGEMM_INSTALL_DIR}/${DEEPGEMM_SUBPKG}
      COMPONENT _deep_gemm_C
      FILES_MATCHING PATTERN "*.py")
  endforeach()

  # envs.py is normally produced by DeepGEMM's setup.py build step; write
  # an equivalent file into the build tree and install it under that name.
  set(DEEPGEMM_ENVS_FILE "${CMAKE_CURRENT_BINARY_DIR}/deep_gemm_envs.py")
  file(WRITE "${DEEPGEMM_ENVS_FILE}"
    "# Pre-installed environment variables\npersistent_envs = dict()\n")
  install(FILES "${DEEPGEMM_ENVS_FILE}"
    DESTINATION ${DEEPGEMM_INSTALL_DIR}
    RENAME envs.py
    COMPONENT _deep_gemm_C)

  #
  # Headers required by the runtime JIT compiler, which looks them up
  # relative to the package directory. Both sets land in the same
  # include/ directory.
  #

  # DeepGEMM's own CUDA headers
  install(DIRECTORY "${DEEPGEMM_PY_SRC}/include/"
    DESTINATION ${DEEPGEMM_INSTALL_DIR}/include
    COMPONENT _deep_gemm_C)

  # CUTLASS and CuTe headers (vendored for JIT, separate from vLLM's CUTLASS)
  install(DIRECTORY "${DEEPGEMM_CUTLASS_INCLUDE}/"
    DESTINATION ${DEEPGEMM_INSTALL_DIR}/include
    COMPONENT _deep_gemm_C)
endif()
|
||||||
@@ -315,7 +315,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
|
|||||||
#################### CSRC BUILD IMAGE ####################
|
#################### CSRC BUILD IMAGE ####################
|
||||||
|
|
||||||
#################### EXTENSIONS BUILD IMAGE ####################
|
#################### EXTENSIONS BUILD IMAGE ####################
|
||||||
# Build DeepGEMM, DeepEP - runs in PARALLEL with csrc-build
|
# Build DeepEP - runs in PARALLEL with csrc-build
|
||||||
# This stage is independent and doesn't affect csrc cache
|
# This stage is independent and doesn't affect csrc cache
|
||||||
FROM base AS extensions-build
|
FROM base AS extensions-build
|
||||||
ARG CUDA_VERSION
|
ARG CUDA_VERSION
|
||||||
@@ -327,21 +327,6 @@ ENV UV_LINK_MODE=copy
|
|||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
# Build DeepGEMM wheel
|
|
||||||
# Default moved here from tools/install_deepgemm.sh for centralized version management
|
|
||||||
ARG DEEPGEMM_GIT_REF=477618cd51baffca09c4b0b87e97c03fe827ef03
|
|
||||||
COPY tools/install_deepgemm.sh /tmp/install_deepgemm.sh
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
|
||||||
mkdir -p /tmp/deepgemm/dist && \
|
|
||||||
VLLM_DOCKER_BUILD_CONTEXT=1 TORCH_CUDA_ARCH_LIST="9.0a 10.0a" /tmp/install_deepgemm.sh \
|
|
||||||
--cuda-version "${CUDA_VERSION}" \
|
|
||||||
${DEEPGEMM_GIT_REF:+--ref "$DEEPGEMM_GIT_REF"} \
|
|
||||||
--wheel-dir /tmp/deepgemm/dist || \
|
|
||||||
echo "DeepGEMM build skipped (CUDA version requirement not met)"
|
|
||||||
|
|
||||||
# Ensure the wheel dir exists so COPY won't fail when DeepGEMM is skipped
|
|
||||||
RUN mkdir -p /tmp/deepgemm/dist && touch /tmp/deepgemm/dist/.deepgemm_skipped
|
|
||||||
|
|
||||||
# Build DeepEP wheels
|
# Build DeepEP wheels
|
||||||
COPY tools/ep_kernels/install_python_libraries.sh /tmp/install_python_libraries.sh
|
COPY tools/ep_kernels/install_python_libraries.sh /tmp/install_python_libraries.sh
|
||||||
# Defaults moved here from tools/ep_kernels/install_python_libraries.sh for centralized version management
|
# Defaults moved here from tools/ep_kernels/install_python_libraries.sh for centralized version management
|
||||||
@@ -426,7 +411,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
|||||||
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38
|
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38
|
||||||
|
|
||||||
# Copy extension wheels from extensions-build stage for later use
|
# Copy extension wheels from extensions-build stage for later use
|
||||||
COPY --from=extensions-build /tmp/deepgemm/dist /tmp/deepgemm/dist
|
|
||||||
COPY --from=extensions-build /tmp/ep_kernels_workspace/dist /tmp/ep_kernels_workspace/dist
|
COPY --from=extensions-build /tmp/ep_kernels_workspace/dist /tmp/ep_kernels_workspace/dist
|
||||||
|
|
||||||
# Check the size of the wheel if RUN_WHEEL_CHECK is true
|
# Check the size of the wheel if RUN_WHEEL_CHECK is true
|
||||||
@@ -693,15 +677,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
|||||||
. /etc/environment && \
|
. /etc/environment && \
|
||||||
uv pip list
|
uv pip list
|
||||||
|
|
||||||
# Install deepgemm wheel that has been built in the `build` stage
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
|
||||||
--mount=type=bind,from=build,source=/tmp/deepgemm/dist,target=/tmp/deepgemm/dist,ro \
|
|
||||||
sh -c 'if ls /tmp/deepgemm/dist/*.whl >/dev/null 2>&1; then \
|
|
||||||
uv pip install --system /tmp/deepgemm/dist/*.whl; \
|
|
||||||
else \
|
|
||||||
echo "No DeepGEMM wheels to install; skipping."; \
|
|
||||||
fi'
|
|
||||||
|
|
||||||
# Pytorch now installs NVSHMEM, setting LD_LIBRARY_PATH
|
# Pytorch now installs NVSHMEM, setting LD_LIBRARY_PATH
|
||||||
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
|
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
|
||||||
|
|
||||||
|
|||||||
@@ -52,9 +52,6 @@
|
|||||||
"vllm_target_device": {
|
"vllm_target_device": {
|
||||||
"default": "cuda"
|
"default": "cuda"
|
||||||
},
|
},
|
||||||
"DEEPGEMM_GIT_REF": {
|
|
||||||
"default": "477618cd51baffca09c4b0b87e97c03fe827ef03"
|
|
||||||
},
|
|
||||||
"DEEPEP_COMMIT_HASH": {
|
"DEEPEP_COMMIT_HASH": {
|
||||||
"default": "73b6ea4"
|
"default": "73b6ea4"
|
||||||
},
|
},
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 325 KiB After Width: | Height: | Size: 315 KiB |
29
setup.py
29
setup.py
@@ -379,6 +379,20 @@ class cmake_build_ext(build_ext):
|
|||||||
dirs_exist_ok=True,
|
dirs_exist_ok=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if _is_cuda():
|
||||||
|
# copy vendored deep_gemm package from build_lib to source tree
|
||||||
|
# for editable installs
|
||||||
|
deep_gemm_build = os.path.join(
|
||||||
|
self.build_lib, "vllm", "third_party", "deep_gemm"
|
||||||
|
)
|
||||||
|
if os.path.exists(deep_gemm_build):
|
||||||
|
print(f"Copying {deep_gemm_build} to vllm/third_party/deep_gemm")
|
||||||
|
shutil.copytree(
|
||||||
|
deep_gemm_build,
|
||||||
|
"vllm/third_party/deep_gemm",
|
||||||
|
dirs_exist_ok=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class precompiled_build_ext(build_ext):
|
class precompiled_build_ext(build_ext):
|
||||||
"""Disables extension building when using precompiled binaries."""
|
"""Disables extension building when using precompiled binaries."""
|
||||||
@@ -685,6 +699,8 @@ class precompiled_wheel_utils:
|
|||||||
flashmla_regex = re.compile(
|
flashmla_regex = re.compile(
|
||||||
r"vllm/third_party/flashmla/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
|
r"vllm/third_party/flashmla/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
|
||||||
)
|
)
|
||||||
|
# DeepGEMM: extract all files (.py, .so, .cuh, .h, .hpp, etc.)
|
||||||
|
deep_gemm_regex = re.compile(r"vllm/third_party/deep_gemm/.*")
|
||||||
file_members = list(
|
file_members = list(
|
||||||
filter(lambda x: x.filename in files_to_copy, wheel.filelist)
|
filter(lambda x: x.filename in files_to_copy, wheel.filelist)
|
||||||
)
|
)
|
||||||
@@ -699,6 +715,9 @@ class precompiled_wheel_utils:
|
|||||||
file_members += list(
|
file_members += list(
|
||||||
filter(lambda x: flashmla_regex.match(x.filename), wheel.filelist)
|
filter(lambda x: flashmla_regex.match(x.filename), wheel.filelist)
|
||||||
)
|
)
|
||||||
|
file_members += list(
|
||||||
|
filter(lambda x: deep_gemm_regex.match(x.filename), wheel.filelist)
|
||||||
|
)
|
||||||
|
|
||||||
for file in file_members:
|
for file in file_members:
|
||||||
print(f"[extract] {file.filename}")
|
print(f"[extract] {file.filename}")
|
||||||
@@ -987,6 +1006,12 @@ if _is_cuda():
|
|||||||
ext_modules.append(
|
ext_modules.append(
|
||||||
CMakeExtension(name="vllm._flashmla_extension_C", optional=True)
|
CMakeExtension(name="vllm._flashmla_extension_C", optional=True)
|
||||||
)
|
)
|
||||||
|
if envs.VLLM_USE_PRECOMPILED or (
|
||||||
|
CUDA_HOME and get_nvcc_cuda_version() >= Version("12.3")
|
||||||
|
):
|
||||||
|
# DeepGEMM requires CUDA 12.3+ (SM90/SM100)
|
||||||
|
# Optional since it won't build on unsupported architectures
|
||||||
|
ext_modules.append(CMakeExtension(name="vllm._deep_gemm_C", optional=True))
|
||||||
|
|
||||||
if _is_cpu():
|
if _is_cpu():
|
||||||
import platform
|
import platform
|
||||||
@@ -1014,6 +1039,10 @@ package_data = {
|
|||||||
"entrypoints/serve/instrumentator/static/*.js",
|
"entrypoints/serve/instrumentator/static/*.js",
|
||||||
"entrypoints/serve/instrumentator/static/*.css",
|
"entrypoints/serve/instrumentator/static/*.css",
|
||||||
"distributed/kv_transfer/kv_connector/v1/hf3fs/utils/*.cpp",
|
"distributed/kv_transfer/kv_connector/v1/hf3fs/utils/*.cpp",
|
||||||
|
# DeepGEMM JIT include headers (vendored via cmake)
|
||||||
|
"third_party/deep_gemm/include/**/*.cuh",
|
||||||
|
"third_party/deep_gemm/include/**/*.h",
|
||||||
|
"third_party/deep_gemm/include/**/*.hpp",
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|||||||
get_fp8_min_max,
|
get_fp8_min_max,
|
||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.deep_gemm import DeepGemmQuantScaleFMT, has_deep_gemm
|
from vllm.utils.deep_gemm import (
|
||||||
|
DeepGemmQuantScaleFMT,
|
||||||
|
has_deep_gemm,
|
||||||
|
transform_sf_into_required_layout,
|
||||||
|
)
|
||||||
from vllm.utils.math_utils import cdiv, round_up
|
from vllm.utils.math_utils import cdiv, round_up
|
||||||
from vllm.utils.torch_utils import set_random_seed
|
from vllm.utils.torch_utils import set_random_seed
|
||||||
|
|
||||||
@@ -256,8 +260,6 @@ def test_silu_mul_fp8_quant_deep_gemm(E: int, T: int, H: int, fp8_type: torch.dt
|
|||||||
and current_platform.has_device_capability(100)
|
and current_platform.has_device_capability(100)
|
||||||
and scale_fmt == DeepGemmQuantScaleFMT.UE8M0
|
and scale_fmt == DeepGemmQuantScaleFMT.UE8M0
|
||||||
):
|
):
|
||||||
from deep_gemm import transform_sf_into_required_layout
|
|
||||||
|
|
||||||
_q, _s = ref_with_scale_fmt(
|
_q, _s = ref_with_scale_fmt(
|
||||||
E,
|
E,
|
||||||
T,
|
T,
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
set -e
|
set -e
|
||||||
|
|
||||||
# Default values
|
# Default values
|
||||||
|
# Keep DEEPGEMM_GIT_REF in sync with cmake/external_projects/deepgemm.cmake
|
||||||
DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git"
|
DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git"
|
||||||
DEEPGEMM_GIT_REF="477618cd51baffca09c4b0b87e97c03fe827ef03"
|
DEEPGEMM_GIT_REF="477618cd51baffca09c4b0b87e97c03fe827ef03"
|
||||||
WHEEL_DIR=""
|
WHEEL_DIR=""
|
||||||
|
|||||||
@@ -138,6 +138,41 @@ _get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None
|
|||||||
_transform_sf_into_required_layout_impl: Callable[..., Any] | None = None
|
_transform_sf_into_required_layout_impl: Callable[..., Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _import_deep_gemm():
    """Import DeepGEMM and return the module, or ``None`` if unavailable.

    Lookup order:

    1. A separately installed ``deep_gemm`` package, so that a version
       pinned by the user wins.
    2. The copy vendored in the vLLM wheel as
       ``vllm.third_party.deep_gemm``.

    For the external package only ``ImportError`` is handled. For the
    vendored copy, any other exception is additionally logged as a warning
    and treated as "not available".
    """
    # Step 1: the external (pip-installed) package.
    try:
        external = importlib.import_module("deep_gemm")
        logger.debug_once("Imported deep_gemm module from site-packages")
        return external
    except ImportError:
        logger.debug_once(
            "deep_gemm not found in site-packages, "
            "trying vendored vllm.third_party.deep_gemm"
        )

    # Step 2: the vendored copy bundled in the vLLM wheel.
    try:
        vendored = importlib.import_module("vllm.third_party.deep_gemm")
        logger.debug_once("Imported deep_gemm module from vllm.third_party.deep_gemm")
        return vendored
    except ImportError:
        logger.debug_once("Vendored deep_gemm not found either")
        return None
    except Exception as e:
        # Importing the vendored module may raise RuntimeError from
        # _C.init() when the JIT include files are missing (for example
        # with an incomplete wheel).
        logger.warning_once("Failed to import vendored deep_gemm: %s", e)
        return None
|
||||||
|
|
||||||
|
|
||||||
def _lazy_init() -> None:
|
def _lazy_init() -> None:
|
||||||
"""Import deep_gemm and resolve symbols on first use."""
|
"""Import deep_gemm and resolve symbols on first use."""
|
||||||
global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
|
global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
|
||||||
@@ -169,7 +204,9 @@ def _lazy_init() -> None:
|
|||||||
envs.VLLM_CACHE_ROOT, "deep_gemm"
|
envs.VLLM_CACHE_ROOT, "deep_gemm"
|
||||||
)
|
)
|
||||||
|
|
||||||
_dg = importlib.import_module("deep_gemm")
|
_dg = _import_deep_gemm()
|
||||||
|
if _dg is None:
|
||||||
|
return
|
||||||
|
|
||||||
_fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
|
_fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
|
||||||
_grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
|
_grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
|
||||||
@@ -193,8 +230,18 @@ def _lazy_init() -> None:
|
|||||||
|
|
||||||
def get_num_sms() -> int:
|
def get_num_sms() -> int:
|
||||||
_lazy_init()
|
_lazy_init()
|
||||||
_dg = importlib.import_module("deep_gemm")
|
dg = _import_deep_gemm()
|
||||||
return int(_dg.get_num_sms())
|
if dg is None:
|
||||||
|
raise RuntimeError("DeepGEMM is not available")
|
||||||
|
return int(dg.get_num_sms())
|
||||||
|
|
||||||
|
|
||||||
|
def set_num_sms(num_sms: int) -> None:
    """Forward ``num_sms`` to DeepGEMM's ``set_num_sms``.

    Args:
        num_sms: Value passed through unchanged to the DeepGEMM module.

    Raises:
        RuntimeError: If no DeepGEMM module could be imported.
    """
    _lazy_init()
    module = _import_deep_gemm()
    if module is not None:
        module.set_num_sms(num_sms)
        return
    raise RuntimeError("DeepGEMM is not available")
|
||||||
|
|
||||||
|
|
||||||
@functools.cache
|
@functools.cache
|
||||||
@@ -446,6 +493,7 @@ __all__ = [
|
|||||||
"is_deep_gemm_e8m0_used",
|
"is_deep_gemm_e8m0_used",
|
||||||
"is_deep_gemm_supported",
|
"is_deep_gemm_supported",
|
||||||
"get_num_sms",
|
"get_num_sms",
|
||||||
|
"set_num_sms",
|
||||||
"should_use_deepgemm_for_fp8_linear",
|
"should_use_deepgemm_for_fp8_linear",
|
||||||
"get_col_major_tma_aligned_tensor",
|
"get_col_major_tma_aligned_tensor",
|
||||||
"get_mk_alignment_for_contiguous_layout",
|
"get_mk_alignment_for_contiguous_layout",
|
||||||
|
|||||||
@@ -408,8 +408,13 @@ def has_deep_ep() -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def has_deep_gemm() -> bool:
|
def has_deep_gemm() -> bool:
|
||||||
"""Whether the optional `deep_gemm` package is available."""
|
"""Whether the optional `deep_gemm` package is available.
|
||||||
return _has_module("deep_gemm")
|
|
||||||
|
Prefers an externally installed ``deep_gemm`` package (so users can
|
||||||
|
override with a newer version), then falls back to the vendored copy
|
||||||
|
bundled in the vLLM wheel.
|
||||||
|
"""
|
||||||
|
return _has_module("deep_gemm") or _has_module("vllm.third_party.deep_gemm")
|
||||||
|
|
||||||
|
|
||||||
def has_nixl_ep() -> bool:
|
def has_nixl_ep() -> bool:
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from vllm.logger import init_logger
|
|||||||
from vllm.model_executor.offloader.base import get_offloader
|
from vllm.model_executor.offloader.base import get_offloader
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
from vllm.utils.deep_gemm import set_num_sms as deep_gemm_set_num_sms
|
||||||
from vllm.utils.import_utils import has_deep_gemm
|
from vllm.utils.import_utils import has_deep_gemm
|
||||||
from vllm.utils.platform_utils import num_compute_units
|
from vllm.utils.platform_utils import num_compute_units
|
||||||
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
|
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
|
||||||
@@ -158,9 +159,7 @@ class UBatchWrapper:
|
|||||||
# TODO(lucas): support other kernels besides DeepGEMM
|
# TODO(lucas): support other kernels besides DeepGEMM
|
||||||
set_compute_sms = lambda sms: None
|
set_compute_sms = lambda sms: None
|
||||||
if has_deep_gemm() and comm_sms > 0:
|
if has_deep_gemm() and comm_sms > 0:
|
||||||
import deep_gemm as dg
|
set_compute_sms = lambda sms: deep_gemm_set_num_sms(sms)
|
||||||
|
|
||||||
set_compute_sms = lambda sms: dg.set_num_sms(sms)
|
|
||||||
|
|
||||||
return SMControlContextManager(
|
return SMControlContextManager(
|
||||||
comm_sms=comm_sms,
|
comm_sms=comm_sms,
|
||||||
|
|||||||
Reference in New Issue
Block a user