Reapply [Attention][FA3] Update FA3 to include new swizzle optimization (#34043)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson
2026-02-11 08:07:56 -07:00
committed by GitHub
parent 1b8756562e
commit c7914d30f9
6 changed files with 60 additions and 44 deletions

View File

@@ -5,7 +5,7 @@ import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, NamedTuple
from typing import Any
import torch
@@ -26,7 +26,8 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
batchsize_forward_time: defaultdict = defaultdict(list)
class BatchDescriptor(NamedTuple):
@dataclass(frozen=True)
class BatchDescriptor:
"""
Batch descriptor for cudagraph dispatching. We should keep the num of
items as minimal as possible to properly and uniquely describe the padded
@@ -56,19 +57,6 @@ class BatchDescriptor(NamedTuple):
to be properly captured.
"""
def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
"""
Return a relaxed version of current batch descriptor that is still compatible
with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
"""
return BatchDescriptor(
self.num_tokens,
num_reqs=None,
uniform=False,
has_lora=self.has_lora,
num_active_loras=self.num_active_loras,
)
def _compute_sp_num_tokens(
num_tokens_across_dp_cpu: torch.Tensor, sequence_parallel_size: int

View File

@@ -40,7 +40,7 @@ from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv
from vllm.utils.math_utils import cdiv, round_up
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionMetadataBuilder,
@@ -310,8 +310,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
if self.use_full_cuda_graph and self.aot_schedule:
# FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4
# The +1 is for the tile_count_semaphore (synchronization).
# The 4 slots per batch element (num_prepare_batch_vectors) are:
# prepare_varlen + dynamic_split + sort_batches + head_swizzle
# See: https://github.com/vllm-project/flash-attention/blob/5824e6e/hopper/flash_api.cpp#L664-L671 # noqa: E501
max_batch_size = max(
vllm_config.scheduler_config.max_num_seqs,
self.max_cudagraph_size or 0,
)
self.scheduler_metadata = torch.zeros(
vllm_config.scheduler_config.max_num_seqs + 1,
1 + round_up(max_batch_size, 4) * 4,
dtype=torch.int32,
device=self.device,
)

View File

@@ -21,6 +21,7 @@ from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import round_up
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer,
@@ -129,8 +130,17 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
if self.use_full_cuda_graph and self.fa_aot_schedule:
# FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4
# The +1 is for the tile_count_semaphore (synchronization).
# The 4 slots per batch element (num_prepare_batch_vectors) are:
# prepare_varlen + dynamic_split + sort_batches + head_swizzle
# See: https://github.com/vllm-project/flash-attention/blob/5824e6e/hopper/flash_api.cpp#L664-L671 # noqa: E501
max_batch_size = max(
vllm_config.scheduler_config.max_num_seqs,
self.max_cudagraph_size or 0,
)
self.scheduler_metadata = torch.zeros(
vllm_config.scheduler_config.max_num_seqs + 1,
1 + round_up(max_batch_size, 4) * 4,
dtype=torch.int32,
device=self.device,
)

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import replace
from itertools import product
from vllm.config import CUDAGraphMode, VllmConfig
@@ -180,12 +181,14 @@ class CudagraphDispatcher:
for bs, num_active_loras in product(
self.compilation_config.cudagraph_capture_sizes, lora_cases
):
self.add_cudagraph_key(
cudagraph_mode.mixed_mode(),
self._create_padded_batch_descriptor(
bs, False, num_active_loras > 0, num_active_loras
).relax_for_mixed_batch_cudagraphs(),
batch_desc = self._create_padded_batch_descriptor(
bs, False, num_active_loras > 0, num_active_loras
)
# Only relax for PIECEWISE mode. FULL mode needs exact num_reqs
# because FA3's scheduler_metadata computation depends on it.
if cudagraph_mode.mixed_mode() == CUDAGraphMode.PIECEWISE:
batch_desc = replace(batch_desc, num_reqs=None, uniform=False)
self.add_cudagraph_key(cudagraph_mode.mixed_mode(), batch_desc)
# if decode cudagraph mode is FULL, and we don't already have mixed
# mode full cudagraphs then add them here.
@@ -264,21 +267,23 @@ class CudagraphDispatcher:
batch_desc = self._create_padded_batch_descriptor(
num_tokens, uniform_decode, has_lora, effective_num_active_loras
)
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()
if not disable_full:
# check if key exists for full cudagraph
if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_desc
# otherwise, check if the relaxed key exists
if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, relaxed_batch_desc
# check if key exists for full cudagraph
# For pure FULL mode, keys are registered with uniform=False.
batch_desc_to_check = batch_desc
if self.cudagraph_mode == CUDAGraphMode.FULL:
batch_desc_to_check = replace(batch_desc, uniform=False)
if (
not disable_full
and batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.FULL]
):
return CUDAGraphMode.FULL, batch_desc_to_check
# also check if the relaxed key exists for more "general"
# piecewise cudagraph
if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, relaxed_batch_desc
batch_desc_to_check = replace(batch_desc, num_reqs=None, uniform=False)
if batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, batch_desc_to_check
# finally, just return no cudagraphs and a trivial batch descriptor
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)