[Feature] Enable TRITON_ATTN for Batch Invariance (#33688)
Signed-off-by: frankwang28 <frank.wbb@hotmail.com>
@@ -1003,8 +1003,11 @@ def vllm_is_batch_invariant() -> bool:
 def override_envs_for_invariance(
     attention_backend: AttentionBackendEnum | None,
 ):
-    supported_backends = [
+    decode_invariant_backends = [
         AttentionBackendEnum.FLASH_ATTN,  # best supported backend
+        AttentionBackendEnum.TRITON_ATTN,
     ]
+    supported_backends = decode_invariant_backends + [
         # FlashInfer temporarily disabled due to non-invariant CTA sizes.
         # See FlashInfer issue #2424
         # AttentionBackendEnum.FLASHINFER,
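For reference, a minimal runnable sketch of the list split this hunk introduces. The enum class below is a hypothetical stand-in for vLLM's AttentionBackendEnum (the real member set is larger); the list contents mirror the diff above.

```python
from enum import Enum

class AttentionBackendEnum(Enum):
    # Hypothetical stub of vLLM's backend enum, for illustration only.
    FLASH_ATTN = "FLASH_ATTN"
    TRITON_ATTN = "TRITON_ATTN"
    FLASHINFER = "FLASHINFER"

# Backends whose batch invariance also holds between prefill and decode.
decode_invariant_backends = [
    AttentionBackendEnum.FLASH_ATTN,  # best supported backend
    AttentionBackendEnum.TRITON_ATTN,
]

# The full supported set extends the decode-invariant subset; FlashInfer
# stays commented out pending FlashInfer issue #2424.
supported_backends = decode_invariant_backends + [
    # AttentionBackendEnum.FLASHINFER,
]

assert set(decode_invariant_backends) <= set(supported_backends)
```

Keeping supported_backends a superset of decode_invariant_backends lets the stricter membership check in the next hunk gate only the prefill/decode warning, not overall backend support.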
@@ -1025,9 +1028,9 @@ def override_envs_for_invariance(
             "one of the supported backends before enabling batch_invariant."
         )
         raise RuntimeError(error)
-    if attention_backend != supported_backends[0]:
+    if attention_backend not in decode_invariant_backends:
         warning = (
-            "You are using a decode-invariant form of batch invariance. "
+            "You are using a non-decode-invariant form of batch invariance. "
             "This will not be invariant between prefill and decode."
         )
         logger.warning_once(warning, scope="local")
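Putting the two hunks together, here is a hedged, self-contained sketch of the resulting control flow. The unsupported-backend guard and the leading text of its error message are not visible in the diff, so both are assumptions; logging.warning stands in for vLLM's logger.warning_once(warning, scope="local").

```python
import logging
from enum import Enum

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

class AttentionBackendEnum(Enum):
    # Hypothetical stub of vLLM's backend enum, as in the sketch above.
    FLASH_ATTN = "FLASH_ATTN"
    TRITON_ATTN = "TRITON_ATTN"

decode_invariant_backends = [
    AttentionBackendEnum.FLASH_ATTN,
    AttentionBackendEnum.TRITON_ATTN,
]
supported_backends = decode_invariant_backends + []  # FLASHINFER still pending

def override_envs_for_invariance(attention_backend):
    # Assumed guard: the diff shows only the tail of this error path,
    # so this condition is a reconstruction, not the exact vLLM code.
    if attention_backend is not None and attention_backend not in supported_backends:
        error = (
            "..."  # leading text of the message is elided in the diff
            "one of the supported backends before enabling batch_invariant."
        )
        raise RuntimeError(error)
    # Changed in this commit: warn only when the backend falls outside
    # the decode-invariant subset, so TRITON_ATTN no longer warns.
    if attention_backend not in decode_invariant_backends:
        logger.warning(  # vLLM uses logger.warning_once(..., scope="local")
            "You are using a non-decode-invariant form of batch invariance. "
            "This will not be invariant between prefill and decode."
        )

override_envs_for_invariance(AttentionBackendEnum.TRITON_ATTN)  # silent now
```

Before this commit, any backend other than FLASH_ATTN triggered the warning; after it, any backend in the decode-invariant subset passes silently, which is what enables TRITON_ATTN here.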