[Attention] Refactor CUDA attention backend selection logic (#24794)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
@@ -10,7 +10,7 @@ from tests.v1.attention.utils import (
     create_vllm_config,
     try_get_attention_backend,
 )
-from vllm.attention.backends.registry import _Backend
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import ParallelConfig, SpeculativeConfig
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 
@@ -35,7 +35,7 @@ def forward_attention(
     block_table: torch.Tensor,
     slot_mapping: torch.Tensor,
     seqlen_k: int,
-    backend: _Backend,
+    backend: AttentionBackendEnum,
     spec_token_tree: str | None = None,
     num_spec_tokens: int = 0,
 ) -> torch.Tensor:
@@ -241,7 +241,7 @@ def test_tree_attn_correctness() -> None:
             block_table=block_table,
             slot_mapping=tree_slot_mapping,
             seqlen_k=seqlen_k,
-            backend=_Backend.TREE_ATTN,
+            backend=AttentionBackendEnum.TREE_ATTN,
             spec_token_tree=spec_token_tree,
             num_spec_tokens=tree_size_q - 1,
         ).view(batch_size, -1, num_heads, dim_per_head)
@@ -278,7 +278,7 @@ def test_tree_attn_correctness() -> None:
             block_table=block_table,
             slot_mapping=branch_slot_mapping,
             seqlen_k=sequence_position + q_len,
-            backend=_Backend.FLASH_ATTN,
+            backend=AttentionBackendEnum.FLASH_ATTN,
         ).view(batch_size, -1, num_heads, dim_per_head)
 
         # Compare the outputs.
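
For readers following the rename: the registry enum previously imported as _Backend is now referenced as AttentionBackendEnum, and the tree-attention tests pass its members (TREE_ATTN, FLASH_ATTN) as the backend argument to forward_attention. The sketch below is a hypothetical helper, not part of this PR, that only illustrates selecting a member of the renamed enum; the helper name and the use_tree_attention flag are assumptions for illustration.

    from vllm.attention.backends.registry import AttentionBackendEnum

    def select_backend(use_tree_attention: bool) -> AttentionBackendEnum:
        # Hypothetical helper for illustration only: return the enum member
        # that the tests above pass as backend= to forward_attention().
        if use_tree_attention:
            return AttentionBackendEnum.TREE_ATTN
        return AttentionBackendEnum.FLASH_ATTN

    # Example usage, mirroring the test calls above (other arguments elided):
    # forward_attention(..., backend=select_backend(use_tree_attention=True), ...)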