[Bugfix] Fix FA3 full cuda graph correctness (#19106)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon
Date: 2025-06-03 23:10:15 -07:00
Committed by: GitHub
Parent: 41aa578428
Commit: b124e1085b
4 changed files with 32 additions and 10 deletions

@@ -307,13 +307,14 @@ class FlashAttentionMetadataBuilder:
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table
 
-        if get_flash_attn_version() == 3:
-            self.aot_schedule = not compilation_config.full_cuda_graph
-            if not self.aot_schedule:
-                logger.warning(
-                    "AOT Schedule is disabled when using full_cuda_graph")
-        else:
-            self.aot_schedule = False
+        self.aot_schedule = (get_flash_attn_version() == 3)
+        self.use_full_cuda_graph = compilation_config.full_cuda_graph
+        if self.use_full_cuda_graph and not self.aot_schedule:
+            raise ValueError("Full CUDA graph mode requires AOT scheduling, "
+                             "which requires FlashAttention 3.")
+        self.scheduler_metadata = torch.zeros(self.runner.max_num_reqs + 1,
+                                              dtype=torch.int32,
+                                              device=self.runner.device)
 
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
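
Why the new preallocated buffer matters: CUDA graph replay reuses the device addresses recorded at capture time, so any tensor the captured kernels read has to live in storage that is allocated once up front and updated in place, never rebound to a fresh allocation each step. A minimal sketch of that pattern, with illustrative names and sizes rather than vLLM's actual API:

    import torch

    class PersistentBuffer:
        """Fixed-size buffer whose storage stays valid across graph replays."""

        def __init__(self, max_entries: int, device: str = "cuda"):
            # Allocated once; captured kernels will always read this storage.
            self.buf = torch.zeros(max_entries, dtype=torch.int32, device=device)

        def update(self, new_values: torch.Tensor) -> torch.Tensor:
            # Copy into the persistent buffer instead of rebinding the name,
            # so the address recorded at capture time stays valid.
            n = new_values.shape[0]
            self.buf[:n].copy_(new_values, non_blocking=True)
            self.buf[n:] = 0  # stale tail entries must not be consumed
            return self.buf[:n]
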
@@ -326,7 +327,7 @@ class FlashAttentionMetadataBuilder:
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata):
-        max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
+        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table = self.block_table
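
About the int() cast: self.runner.seq_lens_np is a host-side NumPy array, so .max() returns a NumPy scalar such as np.int32 rather than a Python int. Wrapping it in int() hands downstream consumers (e.g. native extension bindings that expect a plain integer) a normal Python scalar. A small illustration with made-up values:

    import numpy as np

    seq_lens_np = np.array([7, 12, 5], dtype=np.int32)
    max_seq_len = seq_lens_np[:2].max()       # np.int32(12), a NumPy scalar
    max_seq_len = int(seq_lens_np[:2].max())  # 12, a plain Python int
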
@@ -448,6 +449,18 @@ class FlashAttentionMetadataBuilder:
                                       max_seq_len=max_seq_len,
                                       causal=True)
 
+        if self.use_full_cuda_graph:
+            assert scheduler_metadata is not None
+            n = scheduler_metadata.shape[0]
+            self.scheduler_metadata[:n].copy_(scheduler_metadata,
+                                              non_blocking=True)
+            # NOTE(woosuk): We should zero out the rest of the scheduler
+            # metadata to guarantee the correctness. Otherwise, some thread
+            # blocks may use the invalid scheduler metadata and overwrite the
+            # output buffer.
+            self.scheduler_metadata[n:] = 0
+            scheduler_metadata = self.scheduler_metadata[:n]
+
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
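
The zero-fill in the last hunk is the heart of the correctness fix. Under full CUDA graph mode the captured kernels see the entire preallocated scheduler_metadata buffer on every replay; when the current batch produces only n valid entries, the tail beyond n still holds values from a previous step, and thread blocks that consume those stale entries can overwrite the output buffer, as the NOTE above explains. A self-contained sketch of the hazard and the fix, assuming a CUDA device and using PyTorch's torch.cuda.graph capture API (shapes and values are illustrative):

    import torch

    buf = torch.zeros(8, dtype=torch.int32, device="cuda")

    # Warm up the op outside capture, then capture a kernel that reads
    # the whole buffer on every replay.
    _ = buf * 2
    torch.cuda.synchronize()
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out = buf * 2

    # Step 1: all 8 entries are valid.
    buf.copy_(torch.arange(8, dtype=torch.int32, device="cuda"))
    g.replay()

    # Step 2: only 3 entries are valid. Without zeroing buf[3:], the
    # replayed kernel would still consume step 1's stale tail.
    buf[:3].copy_(torch.tensor([1, 2, 3], dtype=torch.int32, device="cuda"))
    buf[3:] = 0
    g.replay()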