Revert "[Attention][FA3] Update FA3 to include new swizzle optimization" (#33841)
This commit is contained in:
@@ -38,7 +38,7 @@ else()
|
|||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
vllm-flash-attn
|
vllm-flash-attn
|
||||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||||
GIT_TAG 2adfc8c2177c5b0e8ddeedfd5a8990d80eb496ff
|
GIT_TAG 188be16520ceefdc625fdf71365585d2ee348fe2
|
||||||
GIT_PROGRESS TRUE
|
GIT_PROGRESS TRUE
|
||||||
# Don't share the vllm-flash-attn build between build types
|
# Don't share the vllm-flash-attn build between build types
|
||||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||||
|
|||||||
@@ -308,15 +308,10 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
|||||||
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||||
)
|
)
|
||||||
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
|
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
|
||||||
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
|
|
||||||
|
|
||||||
if self.use_full_cuda_graph and self.aot_schedule:
|
if self.use_full_cuda_graph and self.aot_schedule:
|
||||||
# Times 4 due to:
|
|
||||||
# https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653
|
|
||||||
# For some tests max_cudagraph_size > max_num_seqs,
|
|
||||||
# so we need to use the larger one.
|
|
||||||
self.scheduler_metadata = torch.zeros(
|
self.scheduler_metadata = torch.zeros(
|
||||||
max(self.max_cudagraph_size or 0, max_num_seqs) * 4 + 1,
|
vllm_config.scheduler_config.max_num_seqs + 1,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -127,15 +127,10 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
|||||||
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||||
)
|
)
|
||||||
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
|
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
|
||||||
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
|
|
||||||
|
|
||||||
if self.use_full_cuda_graph and self.fa_aot_schedule:
|
if self.use_full_cuda_graph and self.fa_aot_schedule:
|
||||||
# Times 4 due to:
|
|
||||||
# https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653
|
|
||||||
# For some tests max_cudagraph_size > max_num_seqs,
|
|
||||||
# so we need to use the larger one.
|
|
||||||
self.scheduler_metadata = torch.zeros(
|
self.scheduler_metadata = torch.zeros(
|
||||||
max(self.max_cudagraph_size or 0, max_num_seqs) * 4 + 1,
|
vllm_config.scheduler_config.max_num_seqs + 1,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user