[v1] Add encoder-only/cross attention support to Triton Attention backend (#31406)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -15,6 +15,7 @@ from tests.v1.attention.utils import (
|
||||
create_vllm_config,
|
||||
try_get_attention_backend,
|
||||
)
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.platforms import current_platform
|
||||
@@ -83,6 +84,13 @@ BATCH_SPECS = {
|
||||
),
|
||||
"single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
|
||||
"single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
|
||||
# encoder-only
|
||||
"small_encoder_prefill": BatchSpec(
|
||||
seq_lens=[32, 64, 128, 256], query_lens=[32, 64, 128, 256]
|
||||
),
|
||||
"medium_encoder_prefill": BatchSpec(
|
||||
seq_lens=[256, 512, 1024, 2048], query_lens=[256, 512, 1024, 2048]
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -209,6 +217,7 @@ def run_attention_backend(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
sliding_window: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Run attention computation using the specified backend's AttentionImpl."""
|
||||
@@ -276,6 +285,7 @@ def run_attention_backend(
|
||||
num_kv_heads=num_kv_heads,
|
||||
alibi_slopes=None,
|
||||
sliding_window=sliding_window,
|
||||
attn_type=attn_type,
|
||||
kv_cache_dtype="auto",
|
||||
)
|
||||
|
||||
@@ -299,6 +309,7 @@ def _test_backend_correctness(
|
||||
backend_to_test: list[AttentionBackendEnum | str],
|
||||
mask_mod,
|
||||
*,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
block_size: int = 16,
|
||||
atol: float = 1e-2,
|
||||
rtol: float = 1e-2,
|
||||
@@ -436,6 +447,9 @@ def _test_backend_correctness(
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec, vllm_config.cache_config.block_size, device
|
||||
)
|
||||
if attn_type == AttentionType.ENCODER_ONLY:
|
||||
# For encoder-only, all tokens are prefill tokens, so disable causal masking
|
||||
common_attn_metadata.causal = False
|
||||
|
||||
# 3. Simulate Paged KV Cache and a realistic slot_mapping
|
||||
kv_cache = create_and_prepopulate_kv_cache(
|
||||
@@ -491,6 +505,7 @@ def _test_backend_correctness(
|
||||
value_vllm,
|
||||
kv_cache_for_backend,
|
||||
sliding_window=sliding_window,
|
||||
attn_type=attn_type,
|
||||
)
|
||||
finally:
|
||||
if reset_kv_cache_layout:
|
||||
@@ -676,3 +691,45 @@ def test_sliding_window_backend_correctness(
|
||||
block_size=128,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "batch_spec_name",
    [
        "small_encoder_prefill",
        "medium_encoder_prefill",
    ],
)
@pytest.mark.parametrize("model", ["google/embeddinggemma-300m"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
def test_sliding_window_encoder_backend_correctness(
    batch_spec_name: str, model: str, tensor_parallel_size: int
):
    """Check encoder-only attention against a bidirectional sliding-window mask.

    The batch spec is looked up by name in ``BATCH_SPECS``, the window size is
    read from the model's config, and the sliding-window backends are run with
    ``AttentionType.ENCODER_ONLY`` through ``_test_backend_correctness``.
    """
    spec = BATCH_SPECS[batch_spec_name]

    # Size the model config to the longest sequence in the batch, then read
    # the sliding-window width from it.
    longest_seq = max(spec.seq_lens)
    config = ModelConfig(model=model, max_model_len=longest_seq)
    window = config.get_sliding_window()

    def bidi_sliding_window_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor,
        *,
        context_len: int,
        sliding_window: int,
    ):
        # The absolute value makes the window symmetric: keys on either side
        # of the query are visible when closer than `sliding_window`.
        # NOTE(review): `context_len` presumably shifts the query index to its
        # absolute position in the sequence; confirm against the caller.
        distance = torch.abs(q_idx + context_len - kv_idx)
        return distance < sliding_window

    _test_backend_correctness(
        spec,
        model,
        SLIDING_WINDOW_BACKENDS_TO_TEST,
        # Bind the window now; the remaining keyword is left for the helper.
        partial(bidi_sliding_window_mask_mod, sliding_window=window),
        attn_type=AttentionType.ENCODER_ONLY,
        tensor_parallel_size=tensor_parallel_size,
    )
|
||||
|
||||
Reference in New Issue
Block a user