[Bugfix] Pass FA version in MultiHeadAttention (#30575)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer."""
 
+import functools
 from collections.abc import Callable
 from typing import cast
 
@@ -17,6 +18,7 @@ from vllm.attention.backends.abstract import (
 )
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.selector import get_attn_backend
+from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
 from vllm.config import CacheConfig, get_current_vllm_config
@@ -524,6 +526,14 @@ class MultiHeadAttention(nn.Module):
             AttentionBackendEnum.ROCM_AITER_FA,
         }
 
+        self.fa_version = None
+        if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
+            self.fa_version = get_flash_attn_version()
+            assert self._flash_attn_varlen_func is not None
+            self._flash_attn_varlen_func = functools.partial(
+                self._flash_attn_varlen_func, fa_version=self.fa_version
+            )
+
         logger.info_once(
             f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
         )
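The core of the change is resolving the FlashAttention version once at init time and binding it into the cached varlen attention callable with functools.partial, so later calls in the MultiHeadAttention forward path pick up fa_version implicitly. A minimal sketch of that binding pattern follows; varlen_attn and detect_fa_version below are hypothetical stand-ins for illustration, not vLLM APIs.

import functools


def varlen_attn(q, k, v, *, fa_version=None):
    # Hypothetical stand-in for a FlashAttention varlen wrapper; in vLLM the
    # real callable is self._flash_attn_varlen_func.
    print(f"running varlen attention with FA version {fa_version}")
    return q


def detect_fa_version():
    # Stand-in for get_flash_attn_version(); assume FA3 is available here.
    return 3


# Bind the detected version once so every later call uses it implicitly.
fa_version = detect_fa_version()
attn_fn = functools.partial(varlen_attn, fa_version=fa_version)

# Callers now invoke attn_fn(q, k, v) without threading fa_version through.
attn_fn("q", "k", "v")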