diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index dbdfd5e81..b51934a3a 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 2adfc8c2177c5b0e8ddeedfd5a8990d80eb496ff + GIT_TAG 188be16520ceefdc625fdf71365585d2ee348fe2 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 232b0b0da..927572531 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -308,15 +308,10 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad self.compilation_config.cudagraph_mode.has_full_cudagraphs() ) self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size - max_num_seqs = vllm_config.scheduler_config.max_num_seqs if self.use_full_cuda_graph and self.aot_schedule: - # Times 4 due to: - # https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653 - # For some tests max_cudagraph_size > max_num_seqs, - # so we need to use the larger one. self.scheduler_metadata = torch.zeros( - max(self.max_cudagraph_size or 0, max_num_seqs) * 4 + 1, + vllm_config.scheduler_config.max_num_seqs + 1, dtype=torch.int32, device=self.device, ) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index f0ba25936..e160d3255 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -127,15 +127,10 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] self.compilation_config.cudagraph_mode.has_full_cudagraphs() ) self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size - max_num_seqs = vllm_config.scheduler_config.max_num_seqs if self.use_full_cuda_graph and self.fa_aot_schedule: - # Times 4 due to: - # https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653 - # For some tests max_cudagraph_size > max_num_seqs, - # so we need to use the larger one. self.scheduler_metadata = torch.zeros( - max(self.max_cudagraph_size or 0, max_num_seqs) * 4 + 1, + vllm_config.scheduler_config.max_num_seqs + 1, dtype=torch.int32, device=self.device, )