[Feature] Full Cuda Graph Support for Cutlass MLA and 6% E2E Throughput Improvement (#22763)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2025-08-15 02:27:30 -04:00
committed by GitHub
parent b4cef5e6c7
commit 5c3fbfe46b
2 changed files with 88 additions and 2 deletions

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Optional
from typing import ClassVar, Optional
import torch
@@ -12,11 +12,19 @@ from vllm.attention.backends.abstract import (AttentionType,
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata)
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
logger = init_logger(__name__)
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
# enable full CUDA Graph support for decode-only capture
attn_cudagraph_support: ClassVar[
AttentionCGSupport] = AttentionCGSupport.PURE_DECODE_ONLY
class CutlassMLABackend(MLACommonBackend):
@staticmethod
@@ -27,6 +35,10 @@ class CutlassMLABackend(MLACommonBackend):
def get_impl_cls() -> type["CutlassMLAImpl"]:
return CutlassMLAImpl
@staticmethod
def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
return CutlassMLAMetadataBuilder
class SM100Workspace: