[BugFix] Fallback from FA4->FA2 for Batch Invariance (#36059)

Signed-off-by: frankwang28 <frank.wbb@hotmail.com>
This commit is contained in:
Frank Wang
2026-03-05 11:05:56 -08:00
committed by GitHub
parent f917020983
commit a57c877f18

View File

@@ -4,6 +4,7 @@
from typing import Any
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.platforms import current_platform
logger = init_logger(__name__)
@@ -111,6 +112,16 @@ def get_flash_attn_version(
)
fa_version = 2
# FA4 currently uses batch-shape-dependent scheduling
# heuristics on SM100+, which breaks batch invariance.
if vllm_is_batch_invariant() and fa_version == 4:
logger.warning_once(
"Cannot use FA version 4 with batch invariance, "
"defaulting to FA version 2.",
scope="local",
)
fa_version = 2
# FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict
# supported head dimensions.
# See: https://github.com/Dao-AILab/flash-attention/issues/1959