diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake
index 9414f5af3..c56dcb443 100644
--- a/cmake/external_projects/vllm_flash_attn.cmake
+++ b/cmake/external_projects/vllm_flash_attn.cmake
@@ -87,18 +87,40 @@ endforeach()
 #
 add_custom_target(_vllm_fa4_cutedsl_C)
 
-# Copy flash_attn/cute directory (needed for FA4) and transform imports
-# The cute directory uses flash_attn.cute imports internally, which we replace
-# with vllm.vllm_flash_attn.cute to match our package structure.
-install(CODE "
-  file(GLOB_RECURSE CUTE_PY_FILES \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute/*.py\")
-  foreach(SRC_FILE \${CUTE_PY_FILES})
-    file(RELATIVE_PATH REL_PATH \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute\" \${SRC_FILE})
-    set(DST_FILE \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn/cute/\${REL_PATH}\")
-    get_filename_component(DST_DIR \${DST_FILE} DIRECTORY)
-    file(MAKE_DIRECTORY \${DST_DIR})
-    file(READ \${SRC_FILE} FILE_CONTENTS)
-    string(REPLACE \"flash_attn.cute\" \"vllm.vllm_flash_attn.cute\" FILE_CONTENTS \"\${FILE_CONTENTS}\")
-    file(WRITE \${DST_FILE} \"\${FILE_CONTENTS}\")
-  endforeach()
-" COMPONENT _vllm_fa4_cutedsl_C)
+# Install flash_attn/cute directory (needed for FA4).
+# When using a local source dir (VLLM_FLASH_ATTN_SRC_DIR), create a symlink
+# so edits to cute-dsl Python files take effect immediately without rebuilding.
+# Otherwise, copy files and transform flash_attn.cute imports to
+# vllm.vllm_flash_attn.cute to match our package structure.
+if(VLLM_FLASH_ATTN_SRC_DIR)
+  # Clear whatever an earlier install left at the link path (a copied tree or
+  # an older link) before creating the link.
+  install(CODE "
+    set(LINK_TARGET \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute\")
+    set(LINK_NAME \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn/cute\")
+    file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")
+    file(REMOVE_RECURSE \"\${LINK_NAME}\")
+    file(CREATE_LINK \"\${LINK_TARGET}\" \"\${LINK_NAME}\" SYMBOLIC)
+  " COMPONENT _vllm_fa4_cutedsl_C)
+else()
+  # If an earlier symlink-mode install left a link at the destination, remove
+  # it first: otherwise the copy below would write through the link and
+  # overwrite the .py files in that source checkout.
+  install(CODE "
+    set(CUTE_SRC_DIR \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute\")
+    set(CUTE_DST_DIR \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn/cute\")
+    if(IS_SYMLINK \"\${CUTE_DST_DIR}\")
+      file(REMOVE \"\${CUTE_DST_DIR}\")
+    endif()
+    file(GLOB_RECURSE CUTE_PY_FILES \"\${CUTE_SRC_DIR}/*.py\")
+    foreach(SRC_FILE IN LISTS CUTE_PY_FILES)
+      file(RELATIVE_PATH REL_PATH \"\${CUTE_SRC_DIR}\" \"\${SRC_FILE}\")
+      set(DST_FILE \"\${CUTE_DST_DIR}/\${REL_PATH}\")
+      get_filename_component(DST_DIR \"\${DST_FILE}\" DIRECTORY)
+      file(MAKE_DIRECTORY \"\${DST_DIR}\")
+      file(READ \"\${SRC_FILE}\" FILE_CONTENTS)
+      string(REPLACE \"flash_attn.cute\" \"vllm.vllm_flash_attn.cute\" FILE_CONTENTS \"\${FILE_CONTENTS}\")
+      file(WRITE \"\${DST_FILE}\" \"\${FILE_CONTENTS}\")
+    endforeach()
+  " COMPONENT _vllm_fa4_cutedsl_C)
+endif()
diff --git a/vllm/vllm_flash_attn/__init__.py b/vllm/vllm_flash_attn/__init__.py
index 3507defab..7dea1f659 100644
--- a/vllm/vllm_flash_attn/__init__.py
+++ b/vllm/vllm_flash_attn/__init__.py
@@ -1,7 +1,28 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from vllm.vllm_flash_attn.flash_attn_interface import (
+import importlib.machinery
+import os
+import sys
+import types
+
+# In symlink mode (VLLM_FLASH_ATTN_SRC_DIR), cute/ is a symlink to the real
+# source tree and its files use `flash_attn.cute.*` imports (not rewritten).
+# Register a virtual `flash_attn` package so those imports resolve.
+# An already-imported `flash_attn` is left untouched.
+_cute_dir = os.path.join(os.path.dirname(__file__), "cute")
+if os.path.islink(_cute_dir) and "flash_attn" not in sys.modules:
+    _fa_mod = types.ModuleType("flash_attn")
+    # Search path: the directory that holds the link's real target.
+    _fa_mod.__path__ = [os.path.dirname(os.path.realpath(_cute_dir))]
+    _fa_mod.__package__ = "flash_attn"
+    _fa_mod.__spec__ = importlib.machinery.ModuleSpec(
+        "flash_attn", None, is_package=True
+    )
+    _fa_mod.__spec__.submodule_search_locations = _fa_mod.__path__
+    sys.modules["flash_attn"] = _fa_mod
+
+from vllm.vllm_flash_attn.flash_attn_interface import (  # noqa: E402
     FA2_AVAILABLE,
     FA3_AVAILABLE,
     fa_version_unsupported_reason,