[Bugfix][ROCm] Fix Unsupported attention metadata type for speculative decoding in eagle.py (#31714)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
@@ -27,7 +27,6 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.attention.backends.tree_attn import (
|
||||
TreeAttentionMetadata,
|
||||
TreeAttentionMetadataBuilder,
|
||||
@@ -167,7 +166,12 @@ class EagleProposer:
|
||||
# Determine allowed attention backends once during initialization.
|
||||
self.allowed_attn_types: tuple | None = None
|
||||
if current_platform.is_rocm():
|
||||
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
|
||||
from vllm.v1.attention.backends.rocm_attn import RocmAttentionMetadata
|
||||
|
||||
rocm_types = [
|
||||
TritonAttentionMetadata,
|
||||
RocmAttentionMetadata,
|
||||
]
|
||||
# ROCM_AITER_FA is an optional backend
|
||||
if find_spec(
|
||||
AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
|
||||
|
||||
Reference in New Issue
Block a user