[Bugfix] Fix FA3 full cuda graph correctness (#19106)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -307,13 +307,14 @@ class FlashAttentionMetadataBuilder:
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table
 
-        if get_flash_attn_version() == 3:
-            self.aot_schedule = not compilation_config.full_cuda_graph
-            if not self.aot_schedule:
-                logger.warning(
-                    "AOT Schedule is disabled when using full_cuda_graph")
-        else:
-            self.aot_schedule = False
+        self.aot_schedule = (get_flash_attn_version() == 3)
+        self.use_full_cuda_graph = compilation_config.full_cuda_graph
+        if self.use_full_cuda_graph and not self.aot_schedule:
+            raise ValueError("Full CUDA graph mode requires AOT scheduling, "
+                             "which requires FlashAttention 3.")
+        self.scheduler_metadata = torch.zeros(self.runner.max_num_reqs + 1,
+                                              dtype=torch.int32,
+                                              device=self.runner.device)
 
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
@@ -326,7 +327,7 @@ class FlashAttentionMetadataBuilder:
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata):
-        max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
+        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table = self.block_table
@@ -448,6 +449,18 @@ class FlashAttentionMetadataBuilder:
                                       max_seq_len=max_seq_len,
                                       causal=True)
 
+        if self.use_full_cuda_graph:
+            assert scheduler_metadata is not None
+            n = scheduler_metadata.shape[0]
+            self.scheduler_metadata[:n].copy_(scheduler_metadata,
+                                              non_blocking=True)
+            # NOTE(woosuk): We should zero out the rest of the scheduler
+            # metadata to guarantee the correctness. Otherwise, some thread
+            # blocks may use the invalid scheduler metadata and overwrite the
+            # output buffer.
+            self.scheduler_metadata[n:] = 0
+            scheduler_metadata = self.scheduler_metadata[:n]
+
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
|
||||
Reference in New Issue
Block a user