[BugFix] Fallback from FA4->FA2 for Batch Invariance (#36059)
Signed-off-by: frankwang28 <frank.wbb@hotmail.com>
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
from typing import Any
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -111,6 +112,16 @@ def get_flash_attn_version(
|
||||
)
|
||||
fa_version = 2
|
||||
|
||||
# FA4 currently uses batch-shape-dependent scheduling
|
||||
# heuristics on SM100+, which breaks batch invariance.
|
||||
if vllm_is_batch_invariant() and fa_version == 4:
|
||||
logger.warning_once(
|
||||
"Cannot use FA version 4 with batch invariance, "
|
||||
"defaulting to FA version 2.",
|
||||
scope="local",
|
||||
)
|
||||
fa_version = 2
|
||||
|
||||
# FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict
|
||||
# supported head dimensions.
|
||||
# See: https://github.com/Dao-AILab/flash-attention/issues/1959
|
||||
|
||||
Reference in New Issue
Block a user