[Kernel] Support CUDA Graphs in 3D Triton Attention Kernel (#28306)

Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
Signed-off-by: jvlunteren <161835099+jvlunteren@users.noreply.github.com>
Co-authored-by: Thomas Parnell <tom.parnell@gmail.com>
Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
jvlunteren
2025-12-12 16:55:40 +01:00
committed by GitHub
parent 09ad3b76b3
commit 9c0ee995a8
3 changed files with 140 additions and 40 deletions

View File

@@ -17,7 +17,7 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -26,6 +26,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import next_power_of_2
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
@@ -36,6 +37,11 @@ from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
# constants
MIN_LAUNCH_GRID_SIZE_2D = 128 # Minimum launch grid size of 2D kernel
NUM_PAR_SOFTMAX_SEGMENTS = 16 # Number of parallel tiled softmax segments
@dataclass
class TritonAttentionMetadata:
# NOTE(sang): Definition of context_len, query_len, and seq_len.
@@ -54,6 +60,12 @@ class TritonAttentionMetadata:
block_table: torch.Tensor
slot_mapping: torch.Tensor
seq_threshold_3D: int
num_par_softmax_segments: int
softmax_segm_output: torch.Tensor
softmax_segm_max: torch.Tensor
softmax_segm_expsum: torch.Tensor
# For cascade attention.
use_cascade: bool
common_prefix_len: int
@@ -87,6 +99,60 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
self.headdim = model_config.get_head_size()
# Check if CUDA Graphs are enabled for decode
self.decode_cudagraph_enabled = (
self.vllm_config.compilation_config.cudagraph_mode
in (
CUDAGraphMode.FULL_AND_PIECEWISE,
CUDAGraphMode.FULL_DECODE_ONLY,
CUDAGraphMode.FULL,
)
)
# The launch grid for the 2D kernel is defined as (num_q_blocks, num_heads_kv).
# A lower bound for num_q_blocks is the number of sequences.
# To ensure the minimum launch grid size is achieved, the number of sequences
# must be at least equal to the threshold below.
# If this threshold is not reached (i.e., the batch size is not large enough),
# the 3D kernel will be selected instead.
self.seq_threshold_3D = MIN_LAUNCH_GRID_SIZE_2D // self.num_heads_kv
# Modify the threshold if needed.
if self.decode_cudagraph_enabled:
capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
assert capture_sizes, "CUDA Graphs enabled but no capture sizes specified."
# Select the CUDA Graph capture size closest to self.seq_threshold_3D
# as threshold. This ensures that each captured graph covers the
# correct execution path.
self.seq_threshold_3D = min(
capture_sizes,
key=lambda x: abs(x - self.seq_threshold_3D),
)
self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS
headdim_padded = next_power_of_2(self.headdim)
self.softmax_segm_output = torch.empty(
(
self.seq_threshold_3D,
self.num_heads_q,
self.num_par_softmax_segments,
headdim_padded,
),
dtype=torch.float32,
device=device,
)
self.softmax_segm_max = torch.empty(
(self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
dtype=torch.float32,
device=device,
)
self.softmax_segm_expsum = torch.empty(
(self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
dtype=torch.float32,
device=device,
)
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> TritonAttentionMetadata:
@@ -143,6 +209,11 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
prefix_scheduler_metadata=prefix_scheduler_metadata,
seq_threshold_3D=self.seq_threshold_3D,
num_par_softmax_segments=self.num_par_softmax_segments,
softmax_segm_output=self.softmax_segm_output,
softmax_segm_max=self.softmax_segm_max,
softmax_segm_expsum=self.softmax_segm_expsum,
)
return attn_metadata
@@ -349,6 +420,12 @@ class TritonAttentionImpl(AttentionImpl):
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
seq_threshold_3D = attn_metadata.seq_threshold_3D
num_par_softmax_segments = attn_metadata.num_par_softmax_segments
softmax_segm_output = attn_metadata.softmax_segm_output
softmax_segm_max = attn_metadata.softmax_segm_max
softmax_segm_expsum = attn_metadata.softmax_segm_expsum
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
unified_attention(
@@ -369,6 +446,11 @@ class TritonAttentionImpl(AttentionImpl):
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
seq_threshold_3D=seq_threshold_3D,
num_par_softmax_segments=num_par_softmax_segments,
softmax_segm_output=softmax_segm_output,
softmax_segm_max=softmax_segm_max,
softmax_segm_expsum=softmax_segm_expsum,
sinks=self.sinks,
output_scale=output_scale,
)