diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index c098bfb48..c5d6798bd 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -22,9 +22,11 @@ from vllm.v1.attention.backends.fa_utils import (
     get_flash_attn_version,
     is_flash_attn_varlen_func_available,
 )
+from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens
 from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
 from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce
 from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
+from vllm.v1.worker.workspace import current_workspace_manager
 
 if is_flash_attn_varlen_func_available():
     from vllm.v1.attention.backends.fa_utils import (
@@ -52,7 +54,6 @@ from vllm.v1.attention.backend import (
     CommonAttentionMetadata,
 )
 from vllm.v1.attention.backends.utils import (
-    get_dcp_local_seq_lens,
     get_kv_cache_layout,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -356,6 +357,14 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
             self.attention_config.flash_attn_max_num_splits_for_cuda_graph
         )
 
+        if self.dcp_world_size > 1:
+            max_num_reqs = vllm_config.scheduler_config.max_num_seqs
+            self._dcp_context_kv_lens = torch.zeros(
+                max_num_reqs,
+                dtype=torch.int32,
+                device=self.device,
+            )
+
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
         self.aot_sliding_window: tuple[int, int] | None = None
@@ -452,15 +461,18 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
         prefix_scheduler_metadata = None
 
         if self.dcp_world_size > 1:
-            query_kv_lens = query_start_loc[1:] - query_start_loc[:-1]
-            dcp_context_kv_lens = seq_lens - query_kv_lens
-
-            dcp_context_kv_lens = get_dcp_local_seq_lens(
-                dcp_context_kv_lens,
+            query_lens = query_start_loc[1:] - query_start_loc[:-1]
+            context_kv_lens = seq_lens - query_lens
+            local_context_kv_lens = get_dcp_local_seq_lens(
+                context_kv_lens,
                 self.dcp_world_size,
                 self.dcp_rank,
                 self.cp_kv_cache_interleave_size,
             )
+            self._dcp_context_kv_lens[:num_reqs] = local_context_kv_lens
+            self._dcp_context_kv_lens[num_reqs:] = 0
+            dcp_context_kv_lens = self._dcp_context_kv_lens[:num_reqs]
+
             # After DCP distribution, the maximum number of tokens for any rank is
             # ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size,
             # and I is cp_kv_cache_interleave_size.
@@ -637,6 +649,10 @@ class FlashAttentionImpl(AttentionImpl):
         )
         self.dcp_combine = dcp_a2a_lse_reduce if dcp_a2a else cp_lse_ag_out_rs
 
+        self._dcp_dtype: torch.dtype | None = None
+        if vllm_config is not None and self.dcp_world_size > 1:
+            self._dcp_dtype = vllm_config.model_config.dtype
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -862,11 +878,18 @@ class FlashAttentionImpl(AttentionImpl):
             sliding_window_size = (
                 list(self.sliding_window) if self.sliding_window is not None else None
             )
+            n = query_across_dcp.shape[0]
+            (dcp_context_out,) = current_workspace_manager().get_simultaneous(
+                (
+                    (n, self.num_heads * self.dcp_world_size, self.head_size),
+                    self._dcp_dtype,
+                ),
+            )
             context_attn_out, context_lse = flash_attn_varlen_func(
                 q=query_across_dcp,
                 k=key_cache,
                 v=value_cache,
-                out=None,
+                out=dcp_context_out,
                 cu_seqlens_q=cu_seqlens_q,
                 max_seqlen_q=max_seqlen_q,
                 seqused_k=attn_metadata.dcp_context_kv_lens,
@@ -894,11 +917,14 @@ class FlashAttentionImpl(AttentionImpl):
             )
             context_lse_cor = context_lse_cor.transpose(0, 1).contiguous()
 
+            (dcp_query_out,) = current_workspace_manager().get_simultaneous(
+                ((query.shape[0], self.num_heads, self.head_size), self._dcp_dtype),
+            )
             query_attn_out, query_lse = flash_attn_varlen_func(
                 q=query,
                 k=key,
                 v=value,
-                out=None,
+                out=dcp_query_out,
                 cu_seqlens_q=cu_seqlens_q,
                 max_seqlen_q=max_seqlen_q,
                 cu_seqlens_k=cu_seqlens_q,
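# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above): a quick numerical check of
# the per-rank bound quoted in the patch comment, ceil(L / (N * I)) * I, for the
# context KV length after DCP distribution. `local_kv_len` is a hypothetical,
# simplified stand-in for get_dcp_local_seq_lens; it assumes tokens are dealt
# out round-robin in chunks of `interleave` across the DCP ranks.
import math


def local_kv_len(total_kv_len: int, world_size: int, rank: int, interleave: int) -> int:
    """KV tokens owned by `rank` under round-robin interleaving with chunk size `interleave`."""
    full_rounds, rem = divmod(total_kv_len, world_size * interleave)
    owned = full_rounds * interleave
    # Leftover tokens are assigned chunk by chunk to ranks 0, 1, ...
    owned += min(max(rem - rank * interleave, 0), interleave)
    return owned


L, N, I = 1000, 4, 16  # max_seq_len, dcp_world_size, cp_kv_cache_interleave_size
bound = math.ceil(L / (N * I)) * I  # = 256 for these values
assert max(local_kv_len(L, N, r, I) for r in range(N)) <= bound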