[Kernel] Build flash-attn from source (#8245)

This commit is contained in:
Luka Govedič
2024-09-21 02:27:10 -04:00
committed by GitHub
parent 0faab90eb0
commit 71c60491f2
9 changed files with 124 additions and 41 deletions

View File

@@ -6,6 +6,7 @@ import re
import subprocess
import sys
import warnings
from pathlib import Path
from shutil import which
from typing import Dict, List
@@ -152,15 +153,8 @@ class cmake_build_ext(build_ext):
default_cfg = "Debug" if self.debug else "RelWithDebInfo"
cfg = envs.CMAKE_BUILD_TYPE or default_cfg
# where .so files will be written, should be the same for all extensions
# that use the same CMakeLists.txt.
outdir = os.path.abspath(
os.path.dirname(self.get_ext_fullpath(ext.name)))
cmake_args = [
'-DCMAKE_BUILD_TYPE={}'.format(cfg),
'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={}'.format(outdir),
'-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY={}'.format(self.build_temp),
'-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE),
]
@@ -224,10 +218,12 @@ class cmake_build_ext(build_ext):
os.makedirs(self.build_temp)
targets = []
target_name = lambda s: remove_prefix(remove_prefix(s, "vllm."),
"vllm_flash_attn.")
# Build all the extensions
for ext in self.extensions:
self.configure(ext)
targets.append(remove_prefix(ext.name, "vllm."))
targets.append(target_name(ext.name))
num_jobs, _ = self.compute_num_jobs()
@@ -240,6 +236,28 @@ class cmake_build_ext(build_ext):
subprocess.check_call(["cmake", *build_args], cwd=self.build_temp)
# Install the libraries
for ext in self.extensions:
# Install the extension into the proper location
outdir = Path(self.get_ext_fullpath(ext.name)).parent.absolute()
# Skip if the install directory is the same as the build directory
if outdir == self.build_temp:
continue
# CMake appends the extension prefix to the install path,
# and outdir already contains that prefix, so we need to remove it.
prefix = outdir
for i in range(ext.name.count('.')):
prefix = prefix.parent
# prefix here should actually be the same for all components
install_args = [
"cmake", "--install", ".", "--prefix", prefix, "--component",
target_name(ext.name)
]
subprocess.check_call(install_args, cwd=self.build_temp)
def _no_device() -> bool:
return VLLM_TARGET_DEVICE == "empty"
@@ -467,6 +485,10 @@ if _is_cuda() or _is_hip():
if _is_hip():
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
if _is_cuda():
ext_modules.append(
CMakeExtension(name="vllm.vllm_flash_attn.vllm_flash_attn_c"))
if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C"))