[ROCm][Perf] Enable sparse_mla's cudagraph on ROCm platform (#35719)

Signed-off-by: ganyi <ygan@amd.com>
This commit is contained in:
Pleaplusone
2026-03-11 00:14:35 +08:00
committed by GitHub
parent 9095cbbfb6
commit 82f3f30e26
2 changed files with 3 additions and 4 deletions

View File

@@ -151,7 +151,9 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata):
class ROCMAiterMLASparseMetadataBuilder(
AttentionMetadataBuilder[ROCMAiterMLASparseMetadata]
):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
_cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)
def __init__(
self,

View File

@@ -327,9 +327,6 @@ def rocm_fp8_paged_mqa_logits(
aiter_paged_mqa_logits_module = None
if rocm_aiter_ops.is_enabled():
aiter_paged_mqa_logits_module = paged_mqa_logits_module()
# FIXME(ganyi): Temporarily disable the aiter path until nightly docker
# update aiter to the fix PR.
aiter_paged_mqa_logits_module = None
if aiter_paged_mqa_logits_module is not None:
deepgemm_fp8_paged_mqa_logits_stage1 = (