[bugfix] fix early import of flash attention (#12959)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author:       youkaichao
Date:         2025-02-09 00:06:56 +08:00
Committed by: GitHub
Parent:       913df14da3
Commit:       fe743b798d
4 changed files with 20 additions and 19 deletions


@@ -10,7 +10,7 @@ import triton.language as tl
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
-from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
+from vllm.attention.backends.utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.utils import cdiv
 from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -132,6 +132,7 @@ class FlashAttentionImpl(AttentionImpl):
                 "encoder/decoder cross-attention "
                 "are not implemented for "
                 "FlashAttentionImpl")
+        self.vllm_flash_attn_version = get_flash_attn_version()
 
     def forward(
         self,
@@ -205,7 +206,7 @@ class FlashAttentionImpl(AttentionImpl):
             window_size=self.sliding_window,
             block_table=attn_metadata.block_table,
             softcap=self.logits_soft_cap,
-            fa_version=VLLM_FLASH_ATTN_VERSION,
+            fa_version=self.vllm_flash_attn_version,
         )
         return output
@@ -227,7 +228,7 @@ class FlashAttentionImpl(AttentionImpl):
             logits_soft_cap=self.logits_soft_cap,
             block_table=attn_metadata.block_table,
             common_prefix_len=attn_metadata.common_prefix_len,
-            fa_version=VLLM_FLASH_ATTN_VERSION,
+            fa_version=self.vllm_flash_attn_version,
         )
         return output
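
The pattern behind the fix is plain lazy evaluation: the backend module previously resolved the flash-attention version through the module-level constant VLLM_FLASH_ATTN_VERSION, which was evaluated as soon as the module was imported and therefore pulled in vllm.vllm_flash_attn early. After the change, the version is looked up by calling get_flash_attn_version() when FlashAttentionImpl is constructed and is cached on the instance. Below is a minimal, illustrative sketch of that before/after; the body of get_flash_attn_version here is a hypothetical stand-in, not the real helper from vllm.attention.backends.utils.

# Illustrative sketch only -- the helper body is a stand-in, not vLLM's code.

def get_flash_attn_version() -> int:
    # The import happens inside the call, so it is only paid for (and can
    # only fail) when a FlashAttentionImpl is actually constructed.
    try:
        import vllm.vllm_flash_attn  # noqa: F401
        return 3  # stand-in value
    except ImportError:
        return 2  # stand-in fallback


class FlashAttentionImpl:
    def __init__(self) -> None:
        # After the fix: resolved lazily, once per instance.
        self.vllm_flash_attn_version = get_flash_attn_version()
        # Before the fix, the equivalent value came from a module-level
        # constant (VLLM_FLASH_ATTN_VERSION) evaluated at import time.

    def forward(self) -> None:
        # The cached value is forwarded to the attention kernel, mirroring
        # the fa_version=... keyword in the hunks above.
        print(f"fa_version={self.vllm_flash_attn_version}")

Deferring the lookup keeps importing the backend module cheap and side-effect free: environments that never instantiate the flash-attention backend never touch the compiled extension at all.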