Revert "[Attention][FA3] Update FA3 to include new swizzle optimization" (#33841)

This commit is contained in:
Luka Govedič
2026-02-04 22:54:27 -05:00
committed by GitHub
parent fb1270f1f8
commit e3bf79ffa0
3 changed files with 3 additions and 13 deletions

View File

@@ -38,7 +38,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 2adfc8c2177c5b0e8ddeedfd5a8990d80eb496ff
GIT_TAG 188be16520ceefdc625fdf71365585d2ee348fe2
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

View File

@@ -308,15 +308,10 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
if self.use_full_cuda_graph and self.aot_schedule:
# Times 4 due to:
# https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653
# For some tests max_cudagraph_size > max_num_seqs,
# so we need to use the larger one.
self.scheduler_metadata = torch.zeros(
max(self.max_cudagraph_size or 0, max_num_seqs) * 4 + 1,
vllm_config.scheduler_config.max_num_seqs + 1,
dtype=torch.int32,
device=self.device,
)

View File

@@ -127,15 +127,10 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
if self.use_full_cuda_graph and self.fa_aot_schedule:
# Times 4 due to:
# https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653
# For some tests max_cudagraph_size > max_num_seqs,
# so we need to use the larger one.
self.scheduler_metadata = torch.zeros(
max(self.max_cudagraph_size or 0, max_num_seqs) * 4 + 1,
vllm_config.scheduler_config.max_num_seqs + 1,
dtype=torch.int32,
device=self.device,
)