From 82f3f30e266e24b26c46916a8c9daaea7d5e32bd Mon Sep 17 00:00:00 2001 From: Pleaplusone Date: Wed, 11 Mar 2026 00:14:35 +0800 Subject: [PATCH] [ROCm][Perf] Enable `sparse_mla`'s cudagraph on ROCm platform (#35719) Signed-off-by: ganyi --- vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py | 4 +++- vllm/v1/attention/ops/rocm_aiter_mla_sparse.py | 3 --- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py index b1d503ca4..fba59f745 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py @@ -151,7 +151,9 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata): class ROCMAiterMLASparseMetadataBuilder( AttentionMetadataBuilder[ROCMAiterMLASparseMetadata] ): - _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER + _cudagraph_support: ClassVar[AttentionCGSupport] = ( + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) def __init__( self, diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 1b6e6596d..878ae3aac 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -327,9 +327,6 @@ def rocm_fp8_paged_mqa_logits( aiter_paged_mqa_logits_module = None if rocm_aiter_ops.is_enabled(): aiter_paged_mqa_logits_module = paged_mqa_logits_module() - # FIXME(ganyi): Temporarily disable the aiter path until nightly docker - # update aiter to the fix PR. - aiter_paged_mqa_logits_module = None if aiter_paged_mqa_logits_module is not None: deepgemm_fp8_paged_mqa_logits_stage1 = (