From 0f9e7354f508af3fe314cfb709babaaa668f1b04 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 25 Jun 2025 04:39:04 -0400 Subject: [PATCH] [BugFix] Fix full-cuda-graph illegal memory access in FA3 (#20057) Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/flash_attn.py | 25 +++++++----------------- 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 4ad717837..ef65d2ea3 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -158,12 +158,13 @@ class FlashAttentionMetadataBuilder( self.aot_schedule = (get_flash_attn_version() == 3) self.use_full_cuda_graph = compilation_config.full_cuda_graph - if self.use_full_cuda_graph and not self.aot_schedule: - raise ValueError("Full CUDA graph mode requires AOT scheduling, " - "which requires FlashAttention 3.") - self.scheduler_metadata = torch.zeros(self.runner.max_num_reqs + 1, - dtype=torch.int32, - device=self.runner.device) + if self.use_full_cuda_graph: + # NOTE(lucas): AOT scheduling not supported in full cuda graph mode + # yet. This is because the scheduler and kernel need to always use + # the same num_splits (which acts as an upper bound with the + # dynamic split scheduler) which is currently heuristically decided + # by the kernel launching code. + self.aot_schedule = False # Sliding window size to be used with the AOT scheduler will be # populated on first build() call. @@ -299,18 +300,6 @@ class FlashAttentionMetadataBuilder( max_seq_len=max_seq_len, causal=True) - if self.use_full_cuda_graph: - assert scheduler_metadata is not None - n = scheduler_metadata.shape[0] - self.scheduler_metadata[:n].copy_(scheduler_metadata, - non_blocking=True) - # NOTE(woosuk): We should zero out the rest of the scheduler - # metadata to guarantee the correctness. Otherwise, some thread - # blocks may use the invalid scheduler metadata and overwrite the - # output buffer. - self.scheduler_metadata[n:] = 0 - scheduler_metadata = self.scheduler_metadata[:n] - attn_metadata = FlashAttentionMetadata( num_actual_tokens=num_actual_tokens, max_query_len=max_query_len,