[BugFix] Fix full-cuda-graph illegal memory access in FA3 (#20057)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@@ -158,12 +158,13 @@ class FlashAttentionMetadataBuilder(
         self.aot_schedule = (get_flash_attn_version() == 3)
         self.use_full_cuda_graph = compilation_config.full_cuda_graph
-        if self.use_full_cuda_graph and not self.aot_schedule:
-            raise ValueError("Full CUDA graph mode requires AOT scheduling, "
-                             "which requires FlashAttention 3.")
-        self.scheduler_metadata = torch.zeros(self.runner.max_num_reqs + 1,
-                                              dtype=torch.int32,
-                                              device=self.runner.device)
+        if self.use_full_cuda_graph:
+            # NOTE(lucas): AOT scheduling not supported in full cuda graph mode
+            # yet. This is because the scheduler and kernel need to always use
+            # the same num_splits (which acts as an upper bound with the
+            # dynamic split scheduler) which is currently heuristically decided
+            # by the kernel launching code.
+            self.aot_schedule = False
 
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
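The NOTE(lucas) comment above is the core of the fix: once a CUDA graph is captured, the scheduler metadata is baked into the replayed launch, so the scheduler and the kernel must agree on num_splits. The following toy, CPU-only sketch (not vLLM code; all names are invented for illustration) shows what goes wrong when they disagree; the Python IndexError stands in for the illegal memory access seen on the GPU.

    import torch

    def plan_metadata(num_splits: int) -> torch.Tensor:
        # AOT "scheduler": one metadata slot per planned split.
        return torch.arange(num_splits, dtype=torch.int32)

    def launch(metadata: torch.Tensor, kernel_num_splits: int) -> None:
        # "Kernel": every split it decides to run reads its own metadata slot.
        for split in range(kernel_num_splits):
            _ = metadata[split].item()

    meta = plan_metadata(num_splits=4)      # planned, e.g. at graph capture time
    launch(meta, kernel_num_splits=4)       # same num_splits: fine
    try:
        launch(meta, kernel_num_splits=6)   # launch heuristic picks more splits at replay
    except IndexError:
        print("num_splits mismatch -> out-of-bounds metadata access")

Until the launch heuristic can be pinned to the planned value, forcing aot_schedule = False under full CUDA graphs sidesteps the mismatch entirely.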
@@ -299,18 +300,6 @@ class FlashAttentionMetadataBuilder(
             max_seq_len=max_seq_len,
             causal=True)
 
-        if self.use_full_cuda_graph:
-            assert scheduler_metadata is not None
-            n = scheduler_metadata.shape[0]
-            self.scheduler_metadata[:n].copy_(scheduler_metadata,
-                                              non_blocking=True)
-            # NOTE(woosuk): We should zero out the rest of the scheduler
-            # metadata to guarantee the correctness. Otherwise, some thread
-            # blocks may use the invalid scheduler metadata and overwrite the
-            # output buffer.
-            self.scheduler_metadata[n:] = 0
-            scheduler_metadata = self.scheduler_metadata[:n]
-
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
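For reference, the block removed here implemented the usual persistent-buffer pattern for CUDA-graph-compatible inputs: copy the freshly computed scheduler metadata into a preallocated buffer that the captured graph reads from, then zero the tail so no thread block consumes stale entries left over from an earlier, larger batch (the NOTE(woosuk) comment). With AOT scheduling disabled under full CUDA graphs, this buffer is no longer needed. A minimal, CPU-only sketch of the pattern, with assumed names:

    import torch

    MAX_ENTRIES = 129  # e.g. max_num_reqs + 1, as in the removed allocation

    # Preallocated buffer whose address stays fixed across graph replays.
    persistent = torch.zeros(MAX_ENTRIES, dtype=torch.int32)

    def refresh(persistent: torch.Tensor, fresh: torch.Tensor) -> torch.Tensor:
        n = fresh.shape[0]
        persistent[:n].copy_(fresh)   # metadata for the current step
        persistent[n:] = 0            # zero the rest so stale entries are never read
        return persistent[:n]

    view = refresh(persistent, torch.arange(16, dtype=torch.int32))
    print(view.shape)  # torch.Size([16])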