[Feature] Enable TRITON_ATTN for Batch Invariance (#33688)

Signed-off-by: frankwang28 <frank.wbb@hotmail.com>
This commit is contained in:
Frank Wang
2026-02-03 21:27:34 -08:00
committed by GitHub
parent 5e1e0a0fbd
commit 45f8fd6f97
4 changed files with 13 additions and 4 deletions

View File

@@ -108,6 +108,7 @@ Batch invariance has been tested and verified on the following models:
- **Qwen3 (MoE)**: `Qwen/Qwen3-30B-A3B`, `Qwen/Qwen3-Next-80B-A3B-Instruct`
- **Qwen2.5**: `Qwen/Qwen2.5-0.5B-Instruct`, `Qwen/Qwen2.5-1.5B-Instruct`, `Qwen/Qwen2.5-3B-Instruct`, `Qwen/Qwen2.5-7B-Instruct`, `Qwen/Qwen2.5-14B-Instruct`, `Qwen/Qwen2.5-32B-Instruct`
- **Llama 3**: `meta-llama/Llama-3.1-8B-Instruct`, `meta-llama/Llama-3.2-1B-Instruct`
- **GPT-OSS**: `openai/gpt-oss-20b`, `openai/gpt-oss-120b`
Other models may also work, but these have been explicitly validated. If you encounter issues with a specific model, please report them on the [GitHub issue tracker](https://github.com/vllm-project/vllm/issues/new/choose).

View File

@@ -18,6 +18,7 @@ skip_unsupported = pytest.mark.skipif(
BACKENDS: list[str] = [
"FLASH_ATTN",
"TRITON_ATTN",
"TRITON_MLA",
]

View File

@@ -1003,8 +1003,11 @@ def vllm_is_batch_invariant() -> bool:
def override_envs_for_invariance(
attention_backend: AttentionBackendEnum | None,
):
supported_backends = [
decode_invariant_backends = [
AttentionBackendEnum.FLASH_ATTN, # best supported backend
AttentionBackendEnum.TRITON_ATTN,
]
supported_backends = decode_invariant_backends + [
# FlashInfer temporarily disabled due to invariant CTA sizes.
# See FlashInfer issue #2424
# AttentionBackendEnum.FLASHINFER,
@@ -1025,9 +1028,9 @@ def override_envs_for_invariance(
"one of the supported backends before enabling batch_invariant."
)
raise RuntimeError(error)
if attention_backend != supported_backends[0]:
if attention_backend not in decode_invariant_backends:
warning = (
"You are using a decode-invariant form of batch invariance. "
"You are using a non-decode-invariant form of batch invariance. "
"This will not be invariant between prefill and decode."
)
logger.warning_once(warning, scope="local")

View File

@@ -10,10 +10,12 @@
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
logger = init_logger(__name__)
is_batch_invariant = vllm_is_batch_invariant()
float8_info = torch.finfo(current_platform.fp8_dtype())
@@ -972,7 +974,8 @@ def unified_attention(
# Launch the 2D kernel if
# 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
# 2. The batch includes at least one prefill request, or
# 3. The number of sequences exceeds the configured threshold
# 3. The number of sequences exceeds the configured threshold, or
# 4. Batch invariance is enabled
if (
seq_threshold_3D is None
or num_par_softmax_segments is None
@@ -981,6 +984,7 @@ def unified_attention(
or softmax_segm_expsum is None
or max_seqlen_q > 1
or num_seqs > seq_threshold_3D
or is_batch_invariant
):
kernel_unified_attention_2d[
(