[Kernel] Support CUDA Graphs in 3D Triton Attention Kernel (#28306)
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
Signed-off-by: jvlunteren <161835099+jvlunteren@users.noreply.github.com>
Co-authored-by: Thomas Parnell <tom.parnell@gmail.com>
Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
@@ -17,7 +17,7 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash,
|
||||
)
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
@@ -26,6 +26,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.math_utils import next_power_of_2
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
@@ -36,6 +37,11 @@ from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# constants
|
||||
MIN_LAUNCH_GRID_SIZE_2D = 128 # Minimum launch grid size of 2D kernel
|
||||
NUM_PAR_SOFTMAX_SEGMENTS = 16 # Number of parallel tiled softmax segments
|
||||
|
||||
|
||||
@dataclass
|
||||
class TritonAttentionMetadata:
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
@@ -54,6 +60,12 @@ class TritonAttentionMetadata:
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
seq_threshold_3D: int
|
||||
num_par_softmax_segments: int
|
||||
softmax_segm_output: torch.Tensor
|
||||
softmax_segm_max: torch.Tensor
|
||||
softmax_segm_expsum: torch.Tensor
|
||||
|
||||
# For cascade attention.
|
||||
use_cascade: bool
|
||||
common_prefix_len: int
|
||||
@@ -87,6 +99,60 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
|
||||
self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
|
||||
self.headdim = model_config.get_head_size()
|
||||
|
||||
# Check if CUDA Graphs are enabled for decode
|
||||
self.decode_cudagraph_enabled = (
|
||||
self.vllm_config.compilation_config.cudagraph_mode
|
||||
in (
|
||||
CUDAGraphMode.FULL_AND_PIECEWISE,
|
||||
CUDAGraphMode.FULL_DECODE_ONLY,
|
||||
CUDAGraphMode.FULL,
|
||||
)
|
||||
)
|
||||
|
||||
# The launch grid for the 2D kernel is defined as (num_q_blocks, num_heads_kv).
|
||||
# A lower bound for num_q_blocks is the number of sequences.
|
||||
# To ensure the minimum launch grid size is achieved, the number of sequences
|
||||
# must be at least equal to the threshold below.
|
||||
# If this threshold is not reached (i.e., the batch size is not large enough),
|
||||
# the 3D kernel will be selected instead.
|
||||
self.seq_threshold_3D = MIN_LAUNCH_GRID_SIZE_2D // self.num_heads_kv
|
||||
|
||||
# Modify the threshold if needed.
|
||||
if self.decode_cudagraph_enabled:
|
||||
capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
|
||||
assert capture_sizes, "CUDA Graphs enabled but no capture sizes specified."
|
||||
|
||||
# Select the CUDA Graph capture size closest to self.seq_threshold_3D
|
||||
# as threshold. This ensures that each captured graph covers the
|
||||
# correct execution path.
|
||||
self.seq_threshold_3D = min(
|
||||
capture_sizes,
|
||||
key=lambda x: abs(x - self.seq_threshold_3D),
|
||||
)
|
||||
|
||||
self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS
|
||||
headdim_padded = next_power_of_2(self.headdim)
|
||||
self.softmax_segm_output = torch.empty(
|
||||
(
|
||||
self.seq_threshold_3D,
|
||||
self.num_heads_q,
|
||||
self.num_par_softmax_segments,
|
||||
headdim_padded,
|
||||
),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
self.softmax_segm_max = torch.empty(
|
||||
(self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
self.softmax_segm_expsum = torch.empty(
|
||||
(self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata
|
||||
) -> TritonAttentionMetadata:
|
||||
@@ -143,6 +209,11 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
|
||||
prefix_kv_lens=prefix_kv_lens,
|
||||
suffix_kv_lens=suffix_kv_lens,
|
||||
prefix_scheduler_metadata=prefix_scheduler_metadata,
|
||||
seq_threshold_3D=self.seq_threshold_3D,
|
||||
num_par_softmax_segments=self.num_par_softmax_segments,
|
||||
softmax_segm_output=self.softmax_segm_output,
|
||||
softmax_segm_max=self.softmax_segm_max,
|
||||
softmax_segm_expsum=self.softmax_segm_expsum,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
@@ -349,6 +420,12 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
max_seqlen_k = attn_metadata.max_seq_len
|
||||
block_table = attn_metadata.block_table
|
||||
|
||||
seq_threshold_3D = attn_metadata.seq_threshold_3D
|
||||
num_par_softmax_segments = attn_metadata.num_par_softmax_segments
|
||||
softmax_segm_output = attn_metadata.softmax_segm_output
|
||||
softmax_segm_max = attn_metadata.softmax_segm_max
|
||||
softmax_segm_expsum = attn_metadata.softmax_segm_expsum
|
||||
|
||||
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
|
||||
|
||||
unified_attention(
|
||||
@@ -369,6 +446,11 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
q_descale=None, # Not supported
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
seq_threshold_3D=seq_threshold_3D,
|
||||
num_par_softmax_segments=num_par_softmax_segments,
|
||||
softmax_segm_output=softmax_segm_output,
|
||||
softmax_segm_max=softmax_segm_max,
|
||||
softmax_segm_expsum=softmax_segm_expsum,
|
||||
sinks=self.sinks,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user