[Kernel] Support CUDA Graphs in 3D Triton Attention Kernel (#28306)
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
Signed-off-by: jvlunteren <161835099+jvlunteren@users.noreply.github.com>
Co-authored-by: Thomas Parnell <tom.parnell@gmail.com>
Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
@@ -7,6 +7,7 @@ import torch
|
|||||||
|
|
||||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.math_utils import next_power_of_2
|
||||||
|
|
||||||
NUM_HEADS = [(4, 4), (8, 2)]
|
NUM_HEADS = [(4, 4), (8, 2)]
|
||||||
HEAD_SIZES = [128, 256]
|
HEAD_SIZES = [128, 256]
|
||||||
@@ -22,6 +23,10 @@ QDTYPES = (
|
|||||||
# one value small enough to test the schema op check
|
# one value small enough to test the schema op check
|
||||||
NUM_BLOCKS = [32768, 2048]
|
NUM_BLOCKS = [32768, 2048]
|
||||||
|
|
||||||
|
# 0: use 2D kernel for decode
|
||||||
|
# 8: use 3D kernel for decode-only batches of at most 8 sequences
|
||||||
|
SEQ_THRESHOLD_3D_VALUES = [0, 8]
|
||||||
|
|
||||||
|
|
||||||
def ref_paged_attn(
|
def ref_paged_attn(
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
@@ -92,6 +97,7 @@ def ref_paged_attn(
|
|||||||
@pytest.mark.parametrize("soft_cap", [None, 50.0])
|
@pytest.mark.parametrize("soft_cap", [None, 50.0])
|
||||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||||
@pytest.mark.parametrize("q_dtype", QDTYPES)
|
@pytest.mark.parametrize("q_dtype", QDTYPES)
|
||||||
|
@pytest.mark.parametrize("seq_threshold_3D", SEQ_THRESHOLD_3D_VALUES)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_triton_unified_attn(
|
def test_triton_unified_attn(
|
||||||
seq_lens: list[tuple[int, int]],
|
seq_lens: list[tuple[int, int]],
|
||||||
@@ -103,6 +109,7 @@ def test_triton_unified_attn(
|
|||||||
soft_cap: float | None,
|
soft_cap: float | None,
|
||||||
num_blocks: int,
|
num_blocks: int,
|
||||||
q_dtype: torch.dtype | None,
|
q_dtype: torch.dtype | None,
|
||||||
|
seq_threshold_3D: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
|
|
||||||
@@ -152,6 +159,21 @@ def test_triton_unified_attn(
|
|||||||
k_descale = torch.rand(scale_shape, dtype=torch.float32)
|
k_descale = torch.rand(scale_shape, dtype=torch.float32)
|
||||||
v_descale = torch.rand(scale_shape, dtype=torch.float32)
|
v_descale = torch.rand(scale_shape, dtype=torch.float32)
|
||||||
|
|
||||||
|
num_par_softmax_segments = 16
|
||||||
|
head_size_padded = next_power_of_2(head_size)
|
||||||
|
softmax_segm_output = torch.empty(
|
||||||
|
(seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size_padded),
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
softmax_segm_max = torch.empty(
|
||||||
|
(seq_threshold_3D, num_query_heads, num_par_softmax_segments),
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
softmax_segm_expsum = torch.empty(
|
||||||
|
(seq_threshold_3D, num_query_heads, num_par_softmax_segments),
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
|
||||||
unified_attention(
|
unified_attention(
|
||||||
q=maybe_quantized_query,
|
q=maybe_quantized_query,
|
||||||
k=maybe_quantized_key_cache,
|
k=maybe_quantized_key_cache,
|
||||||
@@ -169,6 +191,11 @@ def test_triton_unified_attn(
|
|||||||
q_descale=q_descale,
|
q_descale=q_descale,
|
||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
|
seq_threshold_3D=seq_threshold_3D,
|
||||||
|
num_par_softmax_segments=num_par_softmax_segments,
|
||||||
|
softmax_segm_output=softmax_segm_output,
|
||||||
|
softmax_segm_max=softmax_segm_max,
|
||||||
|
softmax_segm_expsum=softmax_segm_expsum,
|
||||||
)
|
)
|
||||||
|
|
||||||
ref_output = ref_paged_attn(
|
ref_output = ref_paged_attn(
|
||||||
|
|||||||
@@ -355,7 +355,7 @@ def kernel_unified_attention_2d(
|
|||||||
@triton.jit
|
@triton.jit
|
||||||
def kernel_unified_attention_3d(
|
def kernel_unified_attention_3d(
|
||||||
segm_output_ptr,
|
segm_output_ptr,
|
||||||
# [num_tokens, num_query_heads, num_segments, head_size]
|
# [num_tokens, num_query_heads, num_segments, head_size_padded]
|
||||||
segm_max_ptr, # [num_tokens, num_query_heads, num_segments]
|
segm_max_ptr, # [num_tokens, num_query_heads, num_segments]
|
||||||
segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments]
|
segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments]
|
||||||
query_ptr, # [num_tokens, num_query_heads, head_size]
|
query_ptr, # [num_tokens, num_query_heads, head_size]
|
||||||
@@ -749,6 +749,11 @@ def unified_attention(
|
|||||||
q_descale,
|
q_descale,
|
||||||
k_descale,
|
k_descale,
|
||||||
v_descale,
|
v_descale,
|
||||||
|
seq_threshold_3D=None,
|
||||||
|
num_par_softmax_segments=None,
|
||||||
|
softmax_segm_output=None,
|
||||||
|
softmax_segm_max=None,
|
||||||
|
softmax_segm_expsum=None,
|
||||||
alibi_slopes=None,
|
alibi_slopes=None,
|
||||||
output_scale=None,
|
output_scale=None,
|
||||||
qq_bias=None,
|
qq_bias=None,
|
||||||
@@ -793,8 +798,19 @@ def unified_attention(
|
|||||||
TILE_SIZE_PREFILL = 32
|
TILE_SIZE_PREFILL = 32
|
||||||
TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
|
TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
|
||||||
|
|
||||||
# if batch contains a prefill
|
# Launch the 2D kernel if
|
||||||
if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
|
# 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
|
||||||
|
# 2. The batch includes at least one prefill request, or
|
||||||
|
# 3. The number of sequences exceeds the configured threshold
|
||||||
|
if (
|
||||||
|
seq_threshold_3D is None
|
||||||
|
or num_par_softmax_segments is None
|
||||||
|
or softmax_segm_output is None
|
||||||
|
or softmax_segm_max is None
|
||||||
|
or softmax_segm_expsum is None
|
||||||
|
or max_seqlen_q > 1
|
||||||
|
or num_seqs > seq_threshold_3D
|
||||||
|
):
|
||||||
kernel_unified_attention_2d[
|
kernel_unified_attention_2d[
|
||||||
(
|
(
|
||||||
total_num_q_blocks,
|
total_num_q_blocks,
|
||||||
@@ -847,37 +863,12 @@ def unified_attention(
|
|||||||
USE_FP8=output_scale is not None,
|
USE_FP8=output_scale is not None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# for initial version, NUM_SEGMENTS = 16 is chosen as a default
|
kernel_unified_attention_3d[
|
||||||
# value that showed good performance in tests
|
(total_num_q_blocks, num_kv_heads, num_par_softmax_segments)
|
||||||
NUM_SEGMENTS = 16
|
](
|
||||||
|
segm_output_ptr=softmax_segm_output,
|
||||||
segm_output = torch.empty(
|
segm_max_ptr=softmax_segm_max,
|
||||||
q.shape[0],
|
segm_expsum_ptr=softmax_segm_expsum,
|
||||||
num_query_heads,
|
|
||||||
NUM_SEGMENTS,
|
|
||||||
triton.next_power_of_2(head_size),
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=q.device,
|
|
||||||
)
|
|
||||||
segm_max = torch.empty(
|
|
||||||
q.shape[0],
|
|
||||||
num_query_heads,
|
|
||||||
NUM_SEGMENTS,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=q.device,
|
|
||||||
)
|
|
||||||
segm_expsum = torch.empty(
|
|
||||||
q.shape[0],
|
|
||||||
num_query_heads,
|
|
||||||
NUM_SEGMENTS,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=q.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)](
|
|
||||||
segm_output_ptr=segm_output,
|
|
||||||
segm_max_ptr=segm_max,
|
|
||||||
segm_expsum_ptr=segm_expsum,
|
|
||||||
query_ptr=q,
|
query_ptr=q,
|
||||||
key_cache_ptr=k,
|
key_cache_ptr=k,
|
||||||
value_cache_ptr=v,
|
value_cache_ptr=v,
|
||||||
@@ -917,13 +908,13 @@ def unified_attention(
|
|||||||
BLOCK_Q=BLOCK_Q,
|
BLOCK_Q=BLOCK_Q,
|
||||||
num_seqs=num_seqs,
|
num_seqs=num_seqs,
|
||||||
BLOCK_M=BLOCK_M,
|
BLOCK_M=BLOCK_M,
|
||||||
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
|
NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
|
||||||
)
|
)
|
||||||
reduce_segments[(q.shape[0], num_query_heads)](
|
reduce_segments[(q.shape[0], num_query_heads)](
|
||||||
output_ptr=out,
|
output_ptr=out,
|
||||||
segm_output_ptr=segm_output,
|
segm_output_ptr=softmax_segm_output,
|
||||||
segm_max_ptr=segm_max,
|
segm_max_ptr=softmax_segm_max,
|
||||||
segm_expsum_ptr=segm_expsum,
|
segm_expsum_ptr=softmax_segm_expsum,
|
||||||
seq_lens_ptr=seqused_k,
|
seq_lens_ptr=seqused_k,
|
||||||
num_seqs=num_seqs,
|
num_seqs=num_seqs,
|
||||||
num_query_heads=num_query_heads,
|
num_query_heads=num_query_heads,
|
||||||
@@ -936,6 +927,6 @@ def unified_attention(
|
|||||||
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
|
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
|
||||||
query_start_len_ptr=cu_seqlens_q,
|
query_start_len_ptr=cu_seqlens_q,
|
||||||
BLOCK_Q=BLOCK_Q,
|
BLOCK_Q=BLOCK_Q,
|
||||||
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
|
NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
|
||||||
USE_FP8=output_scale is not None,
|
USE_FP8=output_scale is not None,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
|||||||
triton_reshape_and_cache_flash,
|
triton_reshape_and_cache_flash,
|
||||||
)
|
)
|
||||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import CUDAGraphMode, VllmConfig
|
||||||
from vllm.config.cache import CacheDType
|
from vllm.config.cache import CacheDType
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
@@ -26,6 +26,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.platforms.interface import DeviceCapability
|
from vllm.platforms.interface import DeviceCapability
|
||||||
|
from vllm.utils.math_utils import next_power_of_2
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (
|
||||||
AttentionCGSupport,
|
AttentionCGSupport,
|
||||||
AttentionMetadataBuilder,
|
AttentionMetadataBuilder,
|
||||||
@@ -36,6 +37,11 @@ from vllm.v1.kv_cache_interface import AttentionSpec
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# constants
|
||||||
|
MIN_LAUNCH_GRID_SIZE_2D = 128 # Minimum launch grid size of 2D kernel
|
||||||
|
NUM_PAR_SOFTMAX_SEGMENTS = 16 # Number of parallel tiled softmax segments
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TritonAttentionMetadata:
|
class TritonAttentionMetadata:
|
||||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||||
@@ -54,6 +60,12 @@ class TritonAttentionMetadata:
|
|||||||
block_table: torch.Tensor
|
block_table: torch.Tensor
|
||||||
slot_mapping: torch.Tensor
|
slot_mapping: torch.Tensor
|
||||||
|
|
||||||
|
seq_threshold_3D: int
|
||||||
|
num_par_softmax_segments: int
|
||||||
|
softmax_segm_output: torch.Tensor
|
||||||
|
softmax_segm_max: torch.Tensor
|
||||||
|
softmax_segm_expsum: torch.Tensor
|
||||||
|
|
||||||
# For cascade attention.
|
# For cascade attention.
|
||||||
use_cascade: bool
|
use_cascade: bool
|
||||||
common_prefix_len: int
|
common_prefix_len: int
|
||||||
@@ -87,6 +99,60 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
|
|||||||
self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
|
self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
|
||||||
self.headdim = model_config.get_head_size()
|
self.headdim = model_config.get_head_size()
|
||||||
|
|
||||||
|
# Check if CUDA Graphs are enabled for decode
|
||||||
|
self.decode_cudagraph_enabled = (
|
||||||
|
self.vllm_config.compilation_config.cudagraph_mode
|
||||||
|
in (
|
||||||
|
CUDAGraphMode.FULL_AND_PIECEWISE,
|
||||||
|
CUDAGraphMode.FULL_DECODE_ONLY,
|
||||||
|
CUDAGraphMode.FULL,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# The launch grid for the 2D kernel is defined as (num_q_blocks, num_heads_kv).
|
||||||
|
# A lower bound for num_q_blocks is the number of sequences.
|
||||||
|
# To ensure the minimum launch grid size is achieved, the number of sequences
|
||||||
|
# must exceed the threshold below.
|
||||||
|
# If this threshold is not exceeded (i.e., the batch size is not large enough),
|
||||||
|
# the 3D kernel will be selected instead.
|
||||||
|
self.seq_threshold_3D = MIN_LAUNCH_GRID_SIZE_2D // self.num_heads_kv
|
||||||
|
|
||||||
|
# Modify the threshold if needed.
|
||||||
|
if self.decode_cudagraph_enabled:
|
||||||
|
capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
|
||||||
|
assert capture_sizes, "CUDA Graphs enabled but no capture sizes specified."
|
||||||
|
|
||||||
|
# Select the CUDA Graph capture size closest to self.seq_threshold_3D
|
||||||
|
# as threshold. This ensures that each captured graph covers the
|
||||||
|
# correct execution path.
|
||||||
|
self.seq_threshold_3D = min(
|
||||||
|
capture_sizes,
|
||||||
|
key=lambda x: abs(x - self.seq_threshold_3D),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS
|
||||||
|
headdim_padded = next_power_of_2(self.headdim)
|
||||||
|
self.softmax_segm_output = torch.empty(
|
||||||
|
(
|
||||||
|
self.seq_threshold_3D,
|
||||||
|
self.num_heads_q,
|
||||||
|
self.num_par_softmax_segments,
|
||||||
|
headdim_padded,
|
||||||
|
),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
self.softmax_segm_max = torch.empty(
|
||||||
|
(self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
self.softmax_segm_expsum = torch.empty(
|
||||||
|
(self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
def build_for_cudagraph_capture(
|
def build_for_cudagraph_capture(
|
||||||
self, common_attn_metadata: CommonAttentionMetadata
|
self, common_attn_metadata: CommonAttentionMetadata
|
||||||
) -> TritonAttentionMetadata:
|
) -> TritonAttentionMetadata:
|
||||||
@@ -143,6 +209,11 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
|
|||||||
prefix_kv_lens=prefix_kv_lens,
|
prefix_kv_lens=prefix_kv_lens,
|
||||||
suffix_kv_lens=suffix_kv_lens,
|
suffix_kv_lens=suffix_kv_lens,
|
||||||
prefix_scheduler_metadata=prefix_scheduler_metadata,
|
prefix_scheduler_metadata=prefix_scheduler_metadata,
|
||||||
|
seq_threshold_3D=self.seq_threshold_3D,
|
||||||
|
num_par_softmax_segments=self.num_par_softmax_segments,
|
||||||
|
softmax_segm_output=self.softmax_segm_output,
|
||||||
|
softmax_segm_max=self.softmax_segm_max,
|
||||||
|
softmax_segm_expsum=self.softmax_segm_expsum,
|
||||||
)
|
)
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|
||||||
@@ -349,6 +420,12 @@ class TritonAttentionImpl(AttentionImpl):
|
|||||||
max_seqlen_k = attn_metadata.max_seq_len
|
max_seqlen_k = attn_metadata.max_seq_len
|
||||||
block_table = attn_metadata.block_table
|
block_table = attn_metadata.block_table
|
||||||
|
|
||||||
|
seq_threshold_3D = attn_metadata.seq_threshold_3D
|
||||||
|
num_par_softmax_segments = attn_metadata.num_par_softmax_segments
|
||||||
|
softmax_segm_output = attn_metadata.softmax_segm_output
|
||||||
|
softmax_segm_max = attn_metadata.softmax_segm_max
|
||||||
|
softmax_segm_expsum = attn_metadata.softmax_segm_expsum
|
||||||
|
|
||||||
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
|
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
|
||||||
|
|
||||||
unified_attention(
|
unified_attention(
|
||||||
@@ -369,6 +446,11 @@ class TritonAttentionImpl(AttentionImpl):
|
|||||||
q_descale=None, # Not supported
|
q_descale=None, # Not supported
|
||||||
k_descale=layer._k_scale.expand(descale_shape),
|
k_descale=layer._k_scale.expand(descale_shape),
|
||||||
v_descale=layer._v_scale.expand(descale_shape),
|
v_descale=layer._v_scale.expand(descale_shape),
|
||||||
|
seq_threshold_3D=seq_threshold_3D,
|
||||||
|
num_par_softmax_segments=num_par_softmax_segments,
|
||||||
|
softmax_segm_output=softmax_segm_output,
|
||||||
|
softmax_segm_max=softmax_segm_max,
|
||||||
|
softmax_segm_expsum=softmax_segm_expsum,
|
||||||
sinks=self.sinks,
|
sinks=self.sinks,
|
||||||
output_scale=output_scale,
|
output_scale=output_scale,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user