[Attention] Refactor CUDA attention backend selection logic (#24794)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Matthew Bonanni
2025-11-11 06:40:44 -06:00
committed by GitHub
parent 2e78150d24
commit b30dfa03c5
61 changed files with 1338 additions and 1002 deletions

View File

@@ -13,7 +13,7 @@ from tests.v1.attention.utils import (
     create_standard_kv_cache_spec,
     try_get_attention_backend,
 )
-from vllm.attention.backends.registry import _Backend
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import (
     CacheConfig,
     DeviceConfig,
@@ -534,11 +534,17 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
     sampling_metadata = mock.MagicMock()
     if attn_backend == "FLASH_ATTN":
-        attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN)
+        attn_metadata_builder_cls, _ = try_get_attention_backend(
+            AttentionBackendEnum.FLASH_ATTN
+        )
     elif attn_backend == "TRITON_ATTN":
-        attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TRITON_ATTN)
+        attn_metadata_builder_cls, _ = try_get_attention_backend(
+            AttentionBackendEnum.TRITON_ATTN
+        )
     elif attn_backend == "TREE_ATTN":
-        attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN)
+        attn_metadata_builder_cls, _ = try_get_attention_backend(
+            AttentionBackendEnum.TREE_ATTN
+        )
     else:
         raise ValueError(f"Unsupported attention backend: {attn_backend}")
@@ -673,7 +679,9 @@ def test_propose_tree(spec_token_tree):
     proposer.attn_layer_names = ["layer.0"]
     # Get the tree attention metadata builder.
-    attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN)
+    attn_metadata_builder_cls, _ = try_get_attention_backend(
+        AttentionBackendEnum.TREE_ATTN
+    )
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
         layer_names=proposer.attn_layer_names,