[Attention] Refactor CUDA attention backend selection logic (#24794)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
commit b30dfa03c5 (parent 2e78150d24)
Author: Matthew Bonanni
Date: 2025-11-11 06:40:44 -06:00
Committed by: GitHub
61 changed files with 1338 additions and 1002 deletions
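The change at the call sites is mechanical: every use of the private _Backend enum from the attention backend registry is replaced with the public AttentionBackendEnum. A minimal sketch of the migration pattern, using only identifiers that appear in the diffs below (surrounding test scaffolding omitted):

    # Before this commit: backend selection via the private _Backend enum
    #   from vllm.attention.backends.registry import _Backend
    #   attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN)

    # After this commit: the public AttentionBackendEnum is used instead
    from tests.v1.attention.utils import try_get_attention_backend
    from vllm.attention.backends.registry import AttentionBackendEnum

    attn_metadata_builder_cls, _ = try_get_attention_backend(
        AttentionBackendEnum.FLASH_ATTN
    )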


@@ -13,7 +13,7 @@ from tests.v1.attention.utils import (
     create_standard_kv_cache_spec,
     try_get_attention_backend,
 )
-from vllm.attention.backends.registry import _Backend
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import (
     CacheConfig,
     DeviceConfig,
@@ -534,11 +534,17 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
     sampling_metadata = mock.MagicMock()
     if attn_backend == "FLASH_ATTN":
-        attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN)
+        attn_metadata_builder_cls, _ = try_get_attention_backend(
+            AttentionBackendEnum.FLASH_ATTN
+        )
     elif attn_backend == "TRITON_ATTN":
-        attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TRITON_ATTN)
+        attn_metadata_builder_cls, _ = try_get_attention_backend(
+            AttentionBackendEnum.TRITON_ATTN
+        )
     elif attn_backend == "TREE_ATTN":
-        attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN)
+        attn_metadata_builder_cls, _ = try_get_attention_backend(
+            AttentionBackendEnum.TREE_ATTN
+        )
     else:
         raise ValueError(f"Unsupported attention backend: {attn_backend}")
@@ -673,7 +679,9 @@ def test_propose_tree(spec_token_tree):
     proposer.attn_layer_names = ["layer.0"]
     # Get the tree attention metadata builder.
-    attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN)
+    attn_metadata_builder_cls, _ = try_get_attention_backend(
+        AttentionBackendEnum.TREE_ATTN
+    )
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
         layer_names=proposer.attn_layer_names,


@@ -12,7 +12,7 @@ from tests.v1.attention.utils import (
     create_standard_kv_cache_spec,
     try_get_attention_backend,
 )
-from vllm.attention.backends.registry import _Backend
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import (
     CacheConfig,
     DeviceConfig,
@@ -177,7 +177,9 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
     sampling_metadata = mock.MagicMock()
     # Setup attention metadata
-    attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN)
+    attn_metadata_builder_cls, _ = try_get_attention_backend(
+        AttentionBackendEnum.FLASH_ATTN
+    )
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),


@@ -10,7 +10,7 @@ from tests.v1.attention.utils import (
     create_vllm_config,
     try_get_attention_backend,
 )
-from vllm.attention.backends.registry import _Backend
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import ParallelConfig, SpeculativeConfig
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
@@ -35,7 +35,7 @@ def forward_attention(
     block_table: torch.Tensor,
     slot_mapping: torch.Tensor,
     seqlen_k: int,
-    backend: _Backend,
+    backend: AttentionBackendEnum,
     spec_token_tree: str | None = None,
     num_spec_tokens: int = 0,
 ) -> torch.Tensor:
@@ -241,7 +241,7 @@ def test_tree_attn_correctness() -> None:
         block_table=block_table,
         slot_mapping=tree_slot_mapping,
         seqlen_k=seqlen_k,
-        backend=_Backend.TREE_ATTN,
+        backend=AttentionBackendEnum.TREE_ATTN,
         spec_token_tree=spec_token_tree,
         num_spec_tokens=tree_size_q - 1,
     ).view(batch_size, -1, num_heads, dim_per_head)
@@ -278,7 +278,7 @@ def test_tree_attn_correctness() -> None:
         block_table=block_table,
         slot_mapping=branch_slot_mapping,
         seqlen_k=sequence_position + q_len,
-        backend=_Backend.FLASH_ATTN,
+        backend=AttentionBackendEnum.FLASH_ATTN,
     ).view(batch_size, -1, num_heads, dim_per_head)
     # Compare the outputs.