diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index a5554d99f..b985176dc 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -214,11 +214,15 @@ class SpecDecodeBaseProposer: # Determine allowed attention backends once during initialization. self.allowed_attn_types: tuple | None = None if current_platform.is_rocm(): + from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse import ( + ROCMAiterMLASparseMetadata, + ) from vllm.v1.attention.backends.rocm_attn import RocmAttentionMetadata rocm_types = [ TritonAttentionMetadata, RocmAttentionMetadata, + ROCMAiterMLASparseMetadata, ] # ROCM_AITER_FA is an optional backend # We check is_enabled() here to avoid importing the backend module during