[bugfix] fix early import of flash attention (#12959)
Signed-off-by: youkaichao <youkaichao@gmail.com>
@@ -10,7 +10,7 @@ import triton.language as tl
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
-from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
+from vllm.attention.backends.utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.utils import cdiv
 from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -132,6 +132,7 @@ class FlashAttentionImpl(AttentionImpl):
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
                                       "FlashAttentionImpl")
+        self.vllm_flash_attn_version = get_flash_attn_version()
 
     def forward(
         self,
@@ -205,7 +206,7 @@ class FlashAttentionImpl(AttentionImpl):
                 window_size=self.sliding_window,
                 block_table=attn_metadata.block_table,
                 softcap=self.logits_soft_cap,
-                fa_version=VLLM_FLASH_ATTN_VERSION,
+                fa_version=self.vllm_flash_attn_version,
             )
             return output
@@ -227,7 +228,7 @@ class FlashAttentionImpl(AttentionImpl):
            logits_soft_cap=self.logits_soft_cap,
            block_table=attn_metadata.block_table,
            common_prefix_len=attn_metadata.common_prefix_len,
-            fa_version=VLLM_FLASH_ATTN_VERSION,
+            fa_version=self.vllm_flash_attn_version,
        )
        return output
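The change follows a lazy-initialization pattern: instead of resolving the flash-attention version through a module-level constant that forces flash attention to be probed as soon as the backend module is imported, the version is looked up via get_flash_attn_version() only when FlashAttentionImpl is constructed, and each forward call reads the cached instance attribute. Below is a minimal, self-contained sketch of that pattern; probe_flash_attn_version and AttnImplSketch are hypothetical stand-ins, not vLLM APIs, and the "3 if installed else 2" rule is illustrative only.

import importlib.util


def probe_flash_attn_version() -> int:
    # Hypothetical stand-in for vLLM's get_flash_attn_version(): the decision
    # happens only when this function is called, so importing the module that
    # defines the attention impl never touches flash-attention itself.
    # The version rule below is illustrative, not vLLM's real detection logic.
    return 3 if importlib.util.find_spec("flash_attn") is not None else 2


class AttnImplSketch:
    # Hypothetical stand-in for FlashAttentionImpl.

    def __init__(self) -> None:
        # Deferred lookup: the version is resolved at construction time and
        # cached on the instance, mirroring self.vllm_flash_attn_version.
        self.fa_version = probe_flash_attn_version()

    def forward(self) -> str:
        # Call sites read the instance attribute instead of a module-level
        # constant captured at import time.
        return f"dispatching with fa_version={self.fa_version}"


if __name__ == "__main__":
    print(AttnImplSketch().forward())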