Revert "[Attention][FA3] Update FA3 to include new swizzle optimization" (#33841)
This commit is contained in:
@@ -38,7 +38,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 2adfc8c2177c5b0e8ddeedfd5a8990d80eb496ff
|
||||
GIT_TAG 188be16520ceefdc625fdf71365585d2ee348fe2
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
|
||||
@@ -308,15 +308,10 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
)
|
||||
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
|
||||
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
|
||||
|
||||
if self.use_full_cuda_graph and self.aot_schedule:
|
||||
# Times 4 due to:
|
||||
# https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653
|
||||
# For some tests max_cudagraph_size > max_num_seqs,
|
||||
# so we need to use the larger one.
|
||||
self.scheduler_metadata = torch.zeros(
|
||||
max(self.max_cudagraph_size or 0, max_num_seqs) * 4 + 1,
|
||||
vllm_config.scheduler_config.max_num_seqs + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
@@ -127,15 +127,10 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
||||
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
)
|
||||
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
|
||||
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
|
||||
|
||||
if self.use_full_cuda_graph and self.fa_aot_schedule:
|
||||
# Times 4 due to:
|
||||
# https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653
|
||||
# For some tests max_cudagraph_size > max_num_seqs,
|
||||
# so we need to use the larger one.
|
||||
self.scheduler_metadata = torch.zeros(
|
||||
max(self.max_cudagraph_size or 0, max_num_seqs) * 4 + 1,
|
||||
vllm_config.scheduler_config.max_num_seqs + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user