diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py index b1d503ca4..fba59f745 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py @@ -151,7 +151,9 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata): class ROCMAiterMLASparseMetadataBuilder( AttentionMetadataBuilder[ROCMAiterMLASparseMetadata] ): - _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER + _cudagraph_support: ClassVar[AttentionCGSupport] = ( + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) def __init__( self, diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 1b6e6596d..878ae3aac 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -327,9 +327,6 @@ def rocm_fp8_paged_mqa_logits( aiter_paged_mqa_logits_module = None if rocm_aiter_ops.is_enabled(): aiter_paged_mqa_logits_module = paged_mqa_logits_module() - # FIXME(ganyi): Temporarily disable the aiter path until nightly docker - # update aiter to the fix PR. - aiter_paged_mqa_logits_module = None if aiter_paged_mqa_logits_module is not None: deepgemm_fp8_paged_mqa_logits_stage1 = (