[v1] Add encoder-only/cross attention support to Triton Attention backend (#31406)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2026-01-06 00:00:23 +08:00
committed by GitHub
parent 911d38ed99
commit 6aa5b18e1d
6 changed files with 627 additions and 14 deletions

View File

@@ -15,6 +15,7 @@ from tests.v1.attention.utils import (
create_vllm_config,
try_get_attention_backend,
)
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ModelConfig
from vllm.platforms import current_platform
@@ -83,6 +84,13 @@ BATCH_SPECS = {
),
"single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
"single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
# encoder-only
"small_encoder_prefill": BatchSpec(
seq_lens=[32, 64, 128, 256], query_lens=[32, 64, 128, 256]
),
"medium_encoder_prefill": BatchSpec(
seq_lens=[256, 512, 1024, 2048], query_lens=[256, 512, 1024, 2048]
),
}
@@ -209,6 +217,7 @@ def run_attention_backend(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_type: AttentionType = AttentionType.DECODER,
sliding_window: int | None = None,
) -> torch.Tensor:
"""Run attention computation using the specified backend's AttentionImpl."""
@@ -276,6 +285,7 @@ def run_attention_backend(
num_kv_heads=num_kv_heads,
alibi_slopes=None,
sliding_window=sliding_window,
attn_type=attn_type,
kv_cache_dtype="auto",
)
@@ -299,6 +309,7 @@ def _test_backend_correctness(
backend_to_test: list[AttentionBackendEnum | str],
mask_mod,
*,
attn_type: AttentionType = AttentionType.DECODER,
block_size: int = 16,
atol: float = 1e-2,
rtol: float = 1e-2,
@@ -436,6 +447,9 @@ def _test_backend_correctness(
common_attn_metadata = create_common_attn_metadata(
batch_spec, vllm_config.cache_config.block_size, device
)
if attn_type == AttentionType.ENCODER_ONLY:
# Encoder-only attention is non-causal: every token is a prefill token
common_attn_metadata.causal = False
# 3. Simulate Paged KV Cache and a realistic slot_mapping
kv_cache = create_and_prepopulate_kv_cache(
@@ -491,6 +505,7 @@ def _test_backend_correctness(
value_vllm,
kv_cache_for_backend,
sliding_window=sliding_window,
attn_type=attn_type,
)
finally:
if reset_kv_cache_layout:
@@ -676,3 +691,45 @@ def test_sliding_window_backend_correctness(
block_size=128,
tensor_parallel_size=tensor_parallel_size,
)
@pytest.mark.parametrize(
    "batch_spec_name",
    [
        "small_encoder_prefill",
        "medium_encoder_prefill",
    ],
)
@pytest.mark.parametrize("model", ["google/embeddinggemma-300m"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
def test_sliding_window_encoder_backend_correctness(
    batch_spec_name: str, model: str, tensor_parallel_size: int
):
    """Run the sliding-window backends in encoder-only mode.

    The window size is read from the model's config and bound into a
    bidirectional mask function: a key position is kept when its absolute
    distance from the (context-shifted) query position is strictly smaller
    than the window. The mask and ``AttentionType.ENCODER_ONLY`` are handed
    to ``_test_backend_correctness``.
    """
    spec = BATCH_SPECS[batch_spec_name]
    # The model length is capped at the longest sequence in this batch spec.
    longest_seq_len = max(spec.seq_lens)
    config = ModelConfig(model=model, max_model_len=longest_seq_len)
    window = config.get_sliding_window()

    def bidi_sliding_window_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor,
        *,
        context_len: int,
        sliding_window: int,
    ):
        # Symmetric window: the sign of the offset is discarded, so keys on
        # either side of the query are admitted.
        distance = torch.abs(q_idx + context_len - kv_idx)
        return distance < sliding_window

    # `sliding_window` is bound here; `context_len` stays an open keyword
    # argument of the resulting callable.
    mask_mod_fn = partial(bidi_sliding_window_mask_mod, sliding_window=window)

    _test_backend_correctness(
        spec,
        model,
        SLIDING_WINDOW_BACKENDS_TO_TEST,
        mask_mod_fn,
        attn_type=AttentionType.ENCODER_ONLY,
        tensor_parallel_size=tensor_parallel_size,
    )