[ROCm][Perf] Enable sparse_mla's cudagraph on ROCm platform (#35719)
Signed-off-by: ganyi <ygan@amd.com>
This commit is contained in:
@@ -151,7 +151,9 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata):
|
|||||||
class ROCMAiterMLASparseMetadataBuilder(
|
class ROCMAiterMLASparseMetadataBuilder(
|
||||||
AttentionMetadataBuilder[ROCMAiterMLASparseMetadata]
|
AttentionMetadataBuilder[ROCMAiterMLASparseMetadata]
|
||||||
):
|
):
|
||||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
|
_cudagraph_support: ClassVar[AttentionCGSupport] = (
|
||||||
|
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -327,9 +327,6 @@ def rocm_fp8_paged_mqa_logits(
|
|||||||
aiter_paged_mqa_logits_module = None
|
aiter_paged_mqa_logits_module = None
|
||||||
if rocm_aiter_ops.is_enabled():
|
if rocm_aiter_ops.is_enabled():
|
||||||
aiter_paged_mqa_logits_module = paged_mqa_logits_module()
|
aiter_paged_mqa_logits_module = paged_mqa_logits_module()
|
||||||
# FIXME(ganyi): Temporarily disable the aiter path until nightly docker
|
|
||||||
# update aiter to the fix PR.
|
|
||||||
aiter_paged_mqa_logits_module = None
|
|
||||||
|
|
||||||
if aiter_paged_mqa_logits_module is not None:
|
if aiter_paged_mqa_logits_module is not None:
|
||||||
deepgemm_fp8_paged_mqa_logits_stage1 = (
|
deepgemm_fp8_paged_mqa_logits_stage1 = (
|
||||||
|
|||||||
Reference in New Issue
Block a user