[Attention] Refactor CUDA attention backend selection logic (#24794)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
@@ -8,7 +8,7 @@ import pytest
 import torch
 
 from vllm.attention.backends.abstract import AttentionImpl
-from vllm.attention.backends.registry import _Backend, backend_to_class_str
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import (
     CacheConfig,
     CompilationConfig,
@@ -20,7 +20,6 @@ from vllm.config import (
     VllmConfig,
 )
 from vllm.config.model import ModelDType
-from vllm.utils.import_utils import resolve_obj_by_qualname
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
@@ -120,15 +119,14 @@ def create_common_attn_metadata(
 
 
 def try_get_attention_backend(
-    backend: _Backend,
+    backend: AttentionBackendEnum,
 ) -> tuple[type[AttentionMetadataBuilder], type[AttentionImpl]]:
     """Try to get the attention backend class, skipping test if not found."""
-    backend_class_str = backend_to_class_str(backend)
     try:
-        backend_class = resolve_obj_by_qualname(backend_class_str)
+        backend_class = backend.get_class()
         return backend_class.get_builder_cls(), backend_class.get_impl_cls()
     except ImportError as e:
-        pytest.skip(f"{backend_class_str} not available: {e}")
+        pytest.skip(f"{backend.name} not available: {e}")
         raise AssertionError("unreachable") from None
 
 
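For context (not part of this commit): a minimal sketch of what the refactor changes for callers, contrasting the removed string-based resolution with the enum-based API introduced here. The FLASH_ATTN member name is assumed for illustration only; get_class(), get_builder_cls(), and get_impl_cls() are taken from the diff above.

# Sketch only; the FLASH_ATTN member name is an assumption for illustration.
from vllm.attention.backends.registry import AttentionBackendEnum

# Old path (removed by this commit): map the enum to a qualified class name
# string, then import it by name.
#   backend_class_str = backend_to_class_str(_Backend.FLASH_ATTN)
#   backend_class = resolve_obj_by_qualname(backend_class_str)

# New path: the enum member resolves its own backend class directly.
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
builder_cls = backend_class.get_builder_cls()
impl_cls = backend_class.get_impl_cls()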
||||