Reapply [Attention][FA3] Update FA3 to include new swizzle optimization (#34043)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -5,7 +5,7 @@ import time
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, NamedTuple
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
@@ -26,7 +26,8 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
|
||||
batchsize_forward_time: defaultdict = defaultdict(list)
|
||||
|
||||
|
||||
class BatchDescriptor(NamedTuple):
|
||||
@dataclass(frozen=True)
|
||||
class BatchDescriptor:
|
||||
"""
|
||||
Batch descriptor for cudagraph dispatching. We should keep the num of
|
||||
items as minimal as possible to properly and uniquely describe the padded
|
||||
@@ -56,19 +57,6 @@ class BatchDescriptor(NamedTuple):
|
||||
to be properly captured.
|
||||
"""
|
||||
|
||||
def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
|
||||
"""
|
||||
Return a relaxed version of current batch descriptor that is still compatible
|
||||
with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
|
||||
"""
|
||||
return BatchDescriptor(
|
||||
self.num_tokens,
|
||||
num_reqs=None,
|
||||
uniform=False,
|
||||
has_lora=self.has_lora,
|
||||
num_active_loras=self.num_active_loras,
|
||||
)
|
||||
|
||||
|
||||
def _compute_sp_num_tokens(
|
||||
num_tokens_across_dp_cpu: torch.Tensor, sequence_parallel_size: int
|
||||
|
||||
@@ -40,7 +40,7 @@ from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.math_utils import cdiv, round_up
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
@@ -310,8 +310,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
|
||||
|
||||
if self.use_full_cuda_graph and self.aot_schedule:
|
||||
# FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4
|
||||
# The +1 is for the tile_count_semaphore (synchronization).
|
||||
# The 4 slots per batch element (num_prepare_batch_vectors) are:
|
||||
# prepare_varlen + dynamic_split + sort_batches + head_swizzle
|
||||
# See: https://github.com/vllm-project/flash-attention/blob/5824e6e/hopper/flash_api.cpp#L664-L671 # noqa: E501
|
||||
max_batch_size = max(
|
||||
vllm_config.scheduler_config.max_num_seqs,
|
||||
self.max_cudagraph_size or 0,
|
||||
)
|
||||
self.scheduler_metadata = torch.zeros(
|
||||
vllm_config.scheduler_config.max_num_seqs + 1,
|
||||
1 + round_up(max_batch_size, 4) * 4,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.math_utils import round_up
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionCGSupport,
|
||||
AttentionLayer,
|
||||
@@ -129,8 +130,17 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
||||
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
|
||||
|
||||
if self.use_full_cuda_graph and self.fa_aot_schedule:
|
||||
# FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4
|
||||
# The +1 is for the tile_count_semaphore (synchronization).
|
||||
# The 4 slots per batch element (num_prepare_batch_vectors) are:
|
||||
# prepare_varlen + dynamic_split + sort_batches + head_swizzle
|
||||
# See: https://github.com/vllm-project/flash-attention/blob/5824e6e/hopper/flash_api.cpp#L664-L671 # noqa: E501
|
||||
max_batch_size = max(
|
||||
vllm_config.scheduler_config.max_num_seqs,
|
||||
self.max_cudagraph_size or 0,
|
||||
)
|
||||
self.scheduler_metadata = torch.zeros(
|
||||
vllm_config.scheduler_config.max_num_seqs + 1,
|
||||
1 + round_up(max_batch_size, 4) * 4,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import replace
|
||||
from itertools import product
|
||||
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
@@ -180,12 +181,14 @@ class CudagraphDispatcher:
|
||||
for bs, num_active_loras in product(
|
||||
self.compilation_config.cudagraph_capture_sizes, lora_cases
|
||||
):
|
||||
self.add_cudagraph_key(
|
||||
cudagraph_mode.mixed_mode(),
|
||||
self._create_padded_batch_descriptor(
|
||||
bs, False, num_active_loras > 0, num_active_loras
|
||||
).relax_for_mixed_batch_cudagraphs(),
|
||||
batch_desc = self._create_padded_batch_descriptor(
|
||||
bs, False, num_active_loras > 0, num_active_loras
|
||||
)
|
||||
# Only relax for PIECEWISE mode. FULL mode needs exact num_reqs
|
||||
# because FA3's scheduler_metadata computation depends on it.
|
||||
if cudagraph_mode.mixed_mode() == CUDAGraphMode.PIECEWISE:
|
||||
batch_desc = replace(batch_desc, num_reqs=None, uniform=False)
|
||||
self.add_cudagraph_key(cudagraph_mode.mixed_mode(), batch_desc)
|
||||
|
||||
# if decode cudagraph mode is FULL, and we don't already have mixed
|
||||
# mode full cudagraphs then add them here.
|
||||
@@ -264,21 +267,23 @@ class CudagraphDispatcher:
|
||||
batch_desc = self._create_padded_batch_descriptor(
|
||||
num_tokens, uniform_decode, has_lora, effective_num_active_loras
|
||||
)
|
||||
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()
|
||||
|
||||
if not disable_full:
|
||||
# check if key exists for full cudagraph
|
||||
if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
|
||||
return CUDAGraphMode.FULL, batch_desc
|
||||
|
||||
# otherwise, check if the relaxed key exists
|
||||
if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
|
||||
return CUDAGraphMode.FULL, relaxed_batch_desc
|
||||
# check if key exists for full cudagraph
|
||||
# For pure FULL mode, keys are registered with uniform=False.
|
||||
batch_desc_to_check = batch_desc
|
||||
if self.cudagraph_mode == CUDAGraphMode.FULL:
|
||||
batch_desc_to_check = replace(batch_desc, uniform=False)
|
||||
if (
|
||||
not disable_full
|
||||
and batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.FULL]
|
||||
):
|
||||
return CUDAGraphMode.FULL, batch_desc_to_check
|
||||
|
||||
# also check if the relaxed key exists for more "general"
|
||||
# piecewise cudagraph
|
||||
if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
|
||||
return CUDAGraphMode.PIECEWISE, relaxed_batch_desc
|
||||
batch_desc_to_check = replace(batch_desc, num_reqs=None, uniform=False)
|
||||
if batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
|
||||
return CUDAGraphMode.PIECEWISE, batch_desc_to_check
|
||||
|
||||
# finally, just return no cudagraphs and a trivial batch descriptor
|
||||
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
|
||||
|
||||
Reference in New Issue
Block a user