[Feature] Enable TRITON_ATTN for Batch Invariance (#33688)
Signed-off-by: frankwang28 <frank.wbb@hotmail.com>
This commit is contained in:
@@ -108,6 +108,7 @@ Batch invariance has been tested and verified on the following models:
|
||||
- **Qwen3 (MoE)**: `Qwen/Qwen3-30B-A3B`, `Qwen/Qwen3-Next-80B-A3B-Instruct`
|
||||
- **Qwen2.5**: `Qwen/Qwen2.5-0.5B-Instruct`, `Qwen/Qwen2.5-1.5B-Instruct`, `Qwen/Qwen2.5-3B-Instruct`, `Qwen/Qwen2.5-7B-Instruct`, `Qwen/Qwen2.5-14B-Instruct`, `Qwen/Qwen2.5-32B-Instruct`
|
||||
- **Llama 3**: `meta-llama/Llama-3.1-8B-Instruct`, `meta-llama/Llama-3.2-1B-Instruct`
|
||||
- **GPT-OSS**: `openai/gpt-oss-20b`, `openai/gpt-oss-120b`
|
||||
|
||||
Other models may also work, but these have been explicitly validated. If you encounter issues with a specific model, please report them on the [GitHub issue tracker](https://github.com/vllm-project/vllm/issues/new/choose).
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ skip_unsupported = pytest.mark.skipif(
|
||||
|
||||
BACKENDS: list[str] = [
|
||||
"FLASH_ATTN",
|
||||
"TRITON_ATTN",
|
||||
"TRITON_MLA",
|
||||
]
|
||||
|
||||
|
||||
@@ -1003,8 +1003,11 @@ def vllm_is_batch_invariant() -> bool:
|
||||
def override_envs_for_invariance(
|
||||
attention_backend: AttentionBackendEnum | None,
|
||||
):
|
||||
supported_backends = [
|
||||
decode_invariant_backends = [
|
||||
AttentionBackendEnum.FLASH_ATTN, # best supported backend
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
]
|
||||
supported_backends = decode_invariant_backends + [
|
||||
# FlashInfer temporarily disabled due to invariant CTA sizes.
|
||||
# See FlashInfer issue #2424
|
||||
# AttentionBackendEnum.FLASHINFER,
|
||||
@@ -1025,9 +1028,9 @@ def override_envs_for_invariance(
|
||||
"one of the supported backends before enabling batch_invariant."
|
||||
)
|
||||
raise RuntimeError(error)
|
||||
if attention_backend != supported_backends[0]:
|
||||
if attention_backend not in decode_invariant_backends:
|
||||
warning = (
|
||||
"You are using a decode-invariant form of batch invariance. "
|
||||
"You are using a non-decode-invariant form of batch invariance. "
|
||||
"This will not be invariant between prefill and decode."
|
||||
)
|
||||
logger.warning_once(warning, scope="local")
|
||||
|
||||
@@ -10,10 +10,12 @@
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
logger = init_logger(__name__)
|
||||
is_batch_invariant = vllm_is_batch_invariant()
|
||||
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
|
||||
|
||||
@@ -972,7 +974,8 @@ def unified_attention(
|
||||
# Launch the 2D kernel if
|
||||
# 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
|
||||
# 2. The batch includes at least one prefill request, or
|
||||
# 3. The number of sequences exceeds the configured threshold
|
||||
# 3. The number of sequences exceeds the configured threshold, or
|
||||
# 4. Batch invariance is enabled
|
||||
if (
|
||||
seq_threshold_3D is None
|
||||
or num_par_softmax_segments is None
|
||||
@@ -981,6 +984,7 @@ def unified_attention(
|
||||
or softmax_segm_expsum is None
|
||||
or max_seqlen_q > 1
|
||||
or num_seqs > seq_threshold_3D
|
||||
or is_batch_invariant
|
||||
):
|
||||
kernel_unified_attention_2d[
|
||||
(
|
||||
|
||||
Reference in New Issue
Block a user