[Feature] Enable TRITON_ATTN for Batch Invariance (#33688)

Signed-off-by: frankwang28 <frank.wbb@hotmail.com>
Frank Wang
2026-02-03 21:27:34 -08:00
committed by GitHub
parent 5e1e0a0fbd
commit 45f8fd6f97
4 changed files with 13 additions and 4 deletions
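This change promotes TRITON_ATTN to a decode-invariant backend alongside FLASH_ATTN and rewords the backend-mismatch warning to match. Below is a hedged usage sketch, assuming the `VLLM_BATCH_INVARIANT` and `VLLM_ATTENTION_BACKEND` environment variables implied by the surrounding module (`vllm_is_batch_invariant`, `override_envs_for_invariance`); the model name is a placeholder, not taken from this commit.

```python
# Usage sketch, under assumptions: VLLM_BATCH_INVARIANT is the flag read by
# vllm_is_batch_invariant(), and VLLM_ATTENTION_BACKEND selects the backend.
import os

os.environ["VLLM_BATCH_INVARIANT"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN"  # now a decode-invariant choice

from vllm import LLM  # imported after setting env vars so they take effect

llm = LLM(model="facebook/opt-125m")  # placeholder model
print(llm.generate(["Batch invariance smoke test"])[0].outputs[0].text)
```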


@@ -1003,8 +1003,11 @@ def vllm_is_batch_invariant() -> bool:
 def override_envs_for_invariance(
     attention_backend: AttentionBackendEnum | None,
 ):
-    supported_backends = [
+    decode_invariant_backends = [
         AttentionBackendEnum.FLASH_ATTN,  # best supported backend
+        AttentionBackendEnum.TRITON_ATTN,
+    ]
+    supported_backends = decode_invariant_backends + [
         # FlashInfer temporarily disabled due to invariant CTA sizes.
         # See FlashInfer issue #2424
         # AttentionBackendEnum.FLASHINFER,
@@ -1025,9 +1028,9 @@ def override_envs_for_invariance(
             "one of the supported backends before enabling batch_invariant."
         )
         raise RuntimeError(error)
-    if attention_backend != supported_backends[0]:
+    if attention_backend not in decode_invariant_backends:
         warning = (
-            "You are using a decode-invariant form of batch invariance. "
+            "You are using a non-decode-invariant form of batch invariance. "
             "This will not be invariant between prefill and decode."
         )
         logger.warning_once(warning, scope="local")
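For context, here is a minimal, runnable sketch of the check these two hunks produce: decode-invariant backends pass silently, other supported backends only trigger the prefill/decode warning. The `AttentionBackendEnum` stub and its `FLEX_ATTENTION` member are assumptions for the demo (the hunk truncates the rest of the supported list), and stdlib `logging` stands in for vLLM's `logger.warning_once`.

```python
import enum
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("batch_invariant_sketch")


class AttentionBackendEnum(enum.Enum):
    # Stub with only the members needed for the demo; the real enum is larger.
    FLASH_ATTN = enum.auto()
    TRITON_ATTN = enum.auto()
    FLEX_ATTENTION = enum.auto()  # hypothetical member, not taken from the diff


def override_envs_for_invariance(
    attention_backend: AttentionBackendEnum | None,
) -> None:
    # Backends that stay invariant for decode as well as prefill.
    decode_invariant_backends = [
        AttentionBackendEnum.FLASH_ATTN,  # best supported backend
        AttentionBackendEnum.TRITON_ATTN,  # newly enabled by this commit
    ]
    supported_backends = decode_invariant_backends + [
        # Hypothetical extra entry standing in for whatever the truncated
        # hunk appends (FlashInfer is commented out in the real list).
        AttentionBackendEnum.FLEX_ATTENTION,
    ]
    if attention_backend not in supported_backends:
        raise RuntimeError(
            "Set the attention backend to one of the supported backends "
            "before enabling batch_invariant."
        )
    if attention_backend not in decode_invariant_backends:
        logger.warning(
            "You are using a non-decode-invariant form of batch invariance. "
            "This will not be invariant between prefill and decode."
        )


override_envs_for_invariance(AttentionBackendEnum.TRITON_ATTN)  # passes silently
override_envs_for_invariance(AttentionBackendEnum.FLEX_ATTENTION)  # warns
```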