[Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer (#20059)
Signed-off-by: fhl <2410591650@qq.com> Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with FlashAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Optional
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -154,9 +154,26 @@ def _get_sliding_window_configs(
|
||||
|
||||
class FlashAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[FlashAttentionMetadata]):
|
||||
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.NEVER if get_flash_attn_version() == 2 \
|
||||
else AttentionCGSupport.ALWAYS
|
||||
# FA3:
|
||||
# Supports full cudagraphs for all cases.
|
||||
#
|
||||
# FA2:
|
||||
# For FA2, a graph is captured with max_query_len=1, (which is what we
|
||||
# capture by default for num_tokens <= max_num_seqs when there is no
|
||||
# spec-decode) then these graphs will not work for mixed prefill-decode
|
||||
# (unlike FA3). This is due to special max_query_len=1 packed-GQA handling
|
||||
# in FA2.
|
||||
# In summary if we are running with spec decodes the graphs would
|
||||
# work for mixed prefill-decode and uniform-decode. But for non-spec decodes
|
||||
# the graphs would not work for mixed prefill-decode; sorta the inverse
|
||||
# of UNIFORM_SINGLE_TOKEN_DECODE.
|
||||
# Theres probably a better way to describe this using `AttentionCGSupport`
|
||||
# but for now just set it to `UNIFORM_BATCH` to get use to drop down
|
||||
# to FULL_AND_PIECEWISE.
|
||||
# TODO(luka, lucas): audit FA2 as part of:
|
||||
# https://github.com/vllm-project/vllm/issues/22945
|
||||
cudagraph_support = AttentionCGSupport.ALWAYS \
|
||||
if get_flash_attn_version() == 3 else AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||
vllm_config: VllmConfig, device: torch.device):
|
||||
@@ -177,17 +194,13 @@ class FlashAttentionMetadataBuilder(
|
||||
|
||||
self.max_num_splits = 0 # No upper bound on the number of splits.
|
||||
self.aot_schedule = (get_flash_attn_version() == 3)
|
||||
self.use_full_cuda_graph = self.compilation_config.full_cuda_graph
|
||||
if self.use_full_cuda_graph:
|
||||
if not self.aot_schedule:
|
||||
raise ValueError(
|
||||
"AoT scheduling is required for full cuda graph.")
|
||||
capture_sizes = self.compilation_config.cudagraph_capture_sizes
|
||||
if not capture_sizes:
|
||||
raise ValueError(
|
||||
"cudagraph_capture_sizes should not be None when "
|
||||
"full_cuda_graph is True.")
|
||||
self.max_cudagraph_size = max(capture_sizes)
|
||||
|
||||
self.use_full_cuda_graph = \
|
||||
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
|
||||
if self.use_full_cuda_graph and self.aot_schedule:
|
||||
self.max_cudagraph_size = self.compilation_config.max_capture_size
|
||||
|
||||
if self.max_cudagraph_size > 992:
|
||||
# This condition derives from FA3's internal heuristic.
|
||||
# TODO(woosuk): Support larger cudagraph sizes.
|
||||
@@ -310,9 +323,9 @@ class FlashAttentionMetadataBuilder(
|
||||
seqlens=seq_lens,
|
||||
max_seq_len=max_seq_len,
|
||||
causal=causal)
|
||||
|
||||
if self.use_full_cuda_graph:
|
||||
assert scheduler_metadata is not None
|
||||
# For FA3 + full cudagraph
|
||||
max_num_splits = 0
|
||||
if self.use_full_cuda_graph and scheduler_metadata is not None:
|
||||
n = scheduler_metadata.shape[0]
|
||||
self.scheduler_metadata[:n] = scheduler_metadata
|
||||
# NOTE(woosuk): We should zero out the rest of the scheduler
|
||||
@@ -322,14 +335,12 @@ class FlashAttentionMetadataBuilder(
|
||||
self.scheduler_metadata[n:] = 0
|
||||
scheduler_metadata = self.scheduler_metadata[:n]
|
||||
|
||||
max_num_splits = 0
|
||||
if (self.use_full_cuda_graph
|
||||
and num_actual_tokens <= self.max_cudagraph_size):
|
||||
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
|
||||
# usage, because the intermediate buffers of size [num_splits,
|
||||
# num_heads, num_tokens, head_size] are allocated. Therefore,
|
||||
# we only set num_splits when using cuda graphs.
|
||||
max_num_splits = self.max_num_splits
|
||||
if num_actual_tokens <= self.max_cudagraph_size:
|
||||
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
|
||||
# usage, because the intermediate buffers of size [num_splits,
|
||||
# num_heads, num_tokens, head_size] are allocated. Therefore,
|
||||
# we only set num_splits when using cuda graphs.
|
||||
max_num_splits = self.max_num_splits
|
||||
|
||||
attn_metadata = FlashAttentionMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
@@ -350,11 +361,6 @@ class FlashAttentionMetadataBuilder(
|
||||
causal=causal)
|
||||
return attn_metadata
|
||||
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||
# Full CUDA Graph always supported (FA2 support checked separately)
|
||||
return True
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
return use_cascade_attention(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from flashinfer.prefill import trtllm_batch_context_with_kv_cache
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionType)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cdiv, is_pin_memory_available
|
||||
from vllm.utils.flashinfer import use_trtllm_attention
|
||||
@@ -183,8 +183,8 @@ class FlashInferMetadata:
|
||||
|
||||
|
||||
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.PURE_DECODE_ONLY
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
|
||||
reorder_batch_threshold: ClassVar[int] = 1
|
||||
|
||||
@@ -203,7 +203,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
self.kv_cache_spec.block_size)
|
||||
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
|
||||
max_num_pages = max_num_reqs * max_num_pages_per_req
|
||||
self.enable_cuda_graph = self.compilation_config.full_cuda_graph
|
||||
self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\
|
||||
decode_mode() == CUDAGraphMode.FULL
|
||||
if self.enable_cuda_graph:
|
||||
# For full cudagraph capture, one `decode_wrapper` for each batch
|
||||
# size is needed for FlashInfer.
|
||||
@@ -586,10 +587,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
|
||||
return self.build(0, m)
|
||||
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||
return common_attn_metadata.max_query_len == 1
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
|
||||
# TODO: The cascade wrapper currently does not support setting
|
||||
|
||||
@@ -89,8 +89,8 @@ class Mamba2AttentionMetadata:
|
||||
|
||||
class Mamba2AttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[Mamba2AttentionMetadata]):
|
||||
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.PURE_DECODE_ONLY
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
|
||||
reorder_batch_threshold: ClassVar[int] = 1
|
||||
|
||||
@@ -203,7 +203,3 @@ class Mamba2AttentionMetadataBuilder(
|
||||
m.max_query_len = 1 # decode-only
|
||||
|
||||
return self.build(0, m)
|
||||
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||
return common_attn_metadata.max_query_len == 1
|
||||
|
||||
@@ -575,7 +575,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
"MLA only supports decode-only full CUDAGraph capture. " \
|
||||
"Make sure all cudagraph capture sizes <= max_num_seq."
|
||||
|
||||
m.max_query_len = 1 # decode-only
|
||||
assert m.max_query_len == 1 # decode-only
|
||||
|
||||
return self.build(0, m)
|
||||
|
||||
@@ -728,10 +728,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||
return common_attn_metadata.max_query_len == 1
|
||||
|
||||
|
||||
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
"""
|
||||
|
||||
@@ -22,7 +22,7 @@ logger = init_logger(__name__)
|
||||
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
|
||||
# enable full CUDA Graph support for decode-only capture
|
||||
attn_cudagraph_support: ClassVar[
|
||||
AttentionCGSupport] = AttentionCGSupport.PURE_DECODE_ONLY
|
||||
AttentionCGSupport] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
|
||||
|
||||
class CutlassMLABackend(MLACommonBackend):
|
||||
|
||||
@@ -55,8 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
|
||||
|
||||
|
||||
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.PURE_DECODE_ONLY
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||
vllm_config: VllmConfig, device: torch.device):
|
||||
@@ -73,7 +73,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
device_properties = torch.cuda.get_device_properties(self.device)
|
||||
num_sms = device_properties.multi_processor_count
|
||||
|
||||
if self.compilation_config.full_cuda_graph:
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
self.cg_buf_tile_scheduler_metadata = torch.zeros(
|
||||
# Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
|
||||
# TileSchedulerMetaDataSize = 8
|
||||
@@ -95,7 +95,10 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
1, # MQA for the decode path
|
||||
)
|
||||
|
||||
if self.compilation_config.full_cuda_graph:
|
||||
# TODO: we can disambiguate between decode and mixed-prefill decode here
|
||||
# so we can only use the persistent buffer if a cudagraph is actually
|
||||
# being used.
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
assert self.cg_buf_tile_scheduler_metadata is not None
|
||||
assert self.cg_buf_num_splits is not None
|
||||
|
||||
|
||||
@@ -65,8 +65,10 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
|
||||
|
||||
|
||||
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.PURE_DECODE_ONLY
|
||||
# TODO(luka, lucas): audit this as part of:
|
||||
# https://github.com/vllm-project/vllm/issues/22945
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||
vllm_config: VllmConfig, device: torch.device):
|
||||
@@ -82,7 +84,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
max_num_pages = max_num_reqs * max_num_pages_per_req
|
||||
|
||||
# Preparing persistent buffers
|
||||
if vllm_config.compilation_config.full_cuda_graph:
|
||||
# TODO: we can disambiguate between decode and mixed-prefill decode here
|
||||
# so we can only use the persistent buffer if a cudagraph is actually
|
||||
# being used.
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
@@ -120,7 +125,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
block_table_bounds.cumsum(dim=0, dtype=torch.int32)
|
||||
])
|
||||
|
||||
if self.compilation_config.full_cuda_graph:
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
|
||||
num_actual_pages = paged_kv_indices.size(0)
|
||||
|
||||
|
||||
@@ -311,11 +311,6 @@ class AiterFlashAttentionMetadataBuilder(
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||
# Full CUDA Graph always supported (FA2 support checked separately)
|
||||
return True
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@@ -58,8 +58,7 @@ class TritonAttentionMetadata:
|
||||
|
||||
class TritonAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[TritonAttentionMetadata]):
|
||||
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.ALWAYS
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
|
||||
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||
vllm_config: VllmConfig, device: torch.device):
|
||||
@@ -132,11 +131,6 @@ class TritonAttentionMetadataBuilder(
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||
# Full CUDA Graph always supported
|
||||
return True
|
||||
|
||||
|
||||
class TritonAttentionBackend(AttentionBackend):
|
||||
|
||||
|
||||
@@ -158,18 +158,21 @@ class AttentionCGSupport(enum.Enum):
|
||||
Here we do not consider the cascade attention, as currently
|
||||
it is never cudagraph supported."""
|
||||
|
||||
ALWAYS = 3
|
||||
"""Cudagraph always supported; supports mixed-prefill-decode"""
|
||||
UNIFORM_BATCH = 2
|
||||
"""Cudagraph supported for batches the only contain query lengths that are
|
||||
the same, this can be used for spec-decode
|
||||
i.e. "decodes" are 1 + num_speculative_tokens"""
|
||||
UNIFORM_SINGLE_TOKEN_DECODE = 1
|
||||
"""Cudagraph supported for batches the only contain query_len==1 decodes"""
|
||||
NEVER = 0
|
||||
"""NO cudagraph support"""
|
||||
PURE_DECODE_ONLY = 1
|
||||
"""Cudagraph supported for pure decode, need to run without
|
||||
cudagraph for mixed prefill-decode batches"""
|
||||
ALWAYS = 2
|
||||
"""Cudagraph always supported"""
|
||||
|
||||
|
||||
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
||||
# Does this backend/builder support CUDA Graphs for attention.
|
||||
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
# Does this backend/builder support CUDA Graphs for attention (default: no).
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.NEVER
|
||||
# Does this backend/builder reorder the batch?
|
||||
# If not, set this to None. Otherwise set it to the query
|
||||
@@ -199,13 +202,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||
"""
|
||||
Can this batch (with given metadata) use CUDA Graphs for attention.
|
||||
"""
|
||||
return False
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> M:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user