[Kernel] Build flash-attn from source (#8245)
This commit is contained in:
38
setup.py
38
setup.py
@@ -6,6 +6,7 @@ import re
|
||||
import subprocess
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import which
|
||||
from typing import Dict, List
|
||||
|
||||
@@ -152,15 +153,8 @@ class cmake_build_ext(build_ext):
|
||||
default_cfg = "Debug" if self.debug else "RelWithDebInfo"
|
||||
cfg = envs.CMAKE_BUILD_TYPE or default_cfg
|
||||
|
||||
# where .so files will be written, should be the same for all extensions
|
||||
# that use the same CMakeLists.txt.
|
||||
outdir = os.path.abspath(
|
||||
os.path.dirname(self.get_ext_fullpath(ext.name)))
|
||||
|
||||
cmake_args = [
|
||||
'-DCMAKE_BUILD_TYPE={}'.format(cfg),
|
||||
'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={}'.format(outdir),
|
||||
'-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY={}'.format(self.build_temp),
|
||||
'-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE),
|
||||
]
|
||||
|
||||
@@ -224,10 +218,12 @@ class cmake_build_ext(build_ext):
|
||||
os.makedirs(self.build_temp)
|
||||
|
||||
targets = []
|
||||
target_name = lambda s: remove_prefix(remove_prefix(s, "vllm."),
|
||||
"vllm_flash_attn.")
|
||||
# Build all the extensions
|
||||
for ext in self.extensions:
|
||||
self.configure(ext)
|
||||
targets.append(remove_prefix(ext.name, "vllm."))
|
||||
targets.append(target_name(ext.name))
|
||||
|
||||
num_jobs, _ = self.compute_num_jobs()
|
||||
|
||||
@@ -240,6 +236,28 @@ class cmake_build_ext(build_ext):
|
||||
|
||||
subprocess.check_call(["cmake", *build_args], cwd=self.build_temp)
|
||||
|
||||
# Install the libraries
|
||||
for ext in self.extensions:
|
||||
# Install the extension into the proper location
|
||||
outdir = Path(self.get_ext_fullpath(ext.name)).parent.absolute()
|
||||
|
||||
# Skip if the install directory is the same as the build directory
|
||||
if outdir == self.build_temp:
|
||||
continue
|
||||
|
||||
# CMake appends the extension prefix to the install path,
|
||||
# and outdir already contains that prefix, so we need to remove it.
|
||||
prefix = outdir
|
||||
for i in range(ext.name.count('.')):
|
||||
prefix = prefix.parent
|
||||
|
||||
# prefix here should actually be the same for all components
|
||||
install_args = [
|
||||
"cmake", "--install", ".", "--prefix", prefix, "--component",
|
||||
target_name(ext.name)
|
||||
]
|
||||
subprocess.check_call(install_args, cwd=self.build_temp)
|
||||
|
||||
|
||||
def _no_device() -> bool:
|
||||
return VLLM_TARGET_DEVICE == "empty"
|
||||
@@ -467,6 +485,10 @@ if _is_cuda() or _is_hip():
|
||||
if _is_hip():
|
||||
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
|
||||
|
||||
if _is_cuda():
|
||||
ext_modules.append(
|
||||
CMakeExtension(name="vllm.vllm_flash_attn.vllm_flash_attn_c"))
|
||||
|
||||
if _build_custom_ops():
|
||||
ext_modules.append(CMakeExtension(name="vllm._C"))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user