[Bugfix][DCP] Fix CUDA graph capture for Decode Context Parallelism (#36070)
Signed-off-by: Sungsoo Ha <sungsooh@nvidia.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -22,9 +22,11 @@ from vllm.v1.attention.backends.fa_utils import (
|
||||
get_flash_attn_version,
|
||||
is_flash_attn_varlen_func_available,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens
|
||||
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
|
||||
from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce
|
||||
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
|
||||
from vllm.v1.worker.workspace import current_workspace_manager
|
||||
|
||||
if is_flash_attn_varlen_func_available():
|
||||
from vllm.v1.attention.backends.fa_utils import (
|
||||
@@ -52,7 +54,6 @@ from vllm.v1.attention.backend import (
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
get_dcp_local_seq_lens,
|
||||
get_kv_cache_layout,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
@@ -356,6 +357,14 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
self.attention_config.flash_attn_max_num_splits_for_cuda_graph
|
||||
)
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
|
||||
self._dcp_context_kv_lens = torch.zeros(
|
||||
max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# Sliding window size to be used with the AOT scheduler will be
|
||||
# populated on first build() call.
|
||||
self.aot_sliding_window: tuple[int, int] | None = None
|
||||
@@ -452,15 +461,18 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
prefix_scheduler_metadata = None
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
query_kv_lens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
dcp_context_kv_lens = seq_lens - query_kv_lens
|
||||
|
||||
dcp_context_kv_lens = get_dcp_local_seq_lens(
|
||||
dcp_context_kv_lens,
|
||||
query_lens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
context_kv_lens = seq_lens - query_lens
|
||||
local_context_kv_lens = get_dcp_local_seq_lens(
|
||||
context_kv_lens,
|
||||
self.dcp_world_size,
|
||||
self.dcp_rank,
|
||||
self.cp_kv_cache_interleave_size,
|
||||
)
|
||||
self._dcp_context_kv_lens[:num_reqs] = local_context_kv_lens
|
||||
self._dcp_context_kv_lens[num_reqs:] = 0
|
||||
dcp_context_kv_lens = self._dcp_context_kv_lens[:num_reqs]
|
||||
|
||||
# After DCP distribution, the maximum number of tokens for any rank is
|
||||
# ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size,
|
||||
# and I is cp_kv_cache_interleave_size.
|
||||
@@ -637,6 +649,10 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
)
|
||||
self.dcp_combine = dcp_a2a_lse_reduce if dcp_a2a else cp_lse_ag_out_rs
|
||||
|
||||
self._dcp_dtype: torch.dtype | None = None
|
||||
if vllm_config is not None and self.dcp_world_size > 1:
|
||||
self._dcp_dtype = vllm_config.model_config.dtype
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -862,11 +878,18 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
sliding_window_size = (
|
||||
list(self.sliding_window) if self.sliding_window is not None else None
|
||||
)
|
||||
n = query_across_dcp.shape[0]
|
||||
(dcp_context_out,) = current_workspace_manager().get_simultaneous(
|
||||
(
|
||||
(n, self.num_heads * self.dcp_world_size, self.head_size),
|
||||
self._dcp_dtype,
|
||||
),
|
||||
)
|
||||
context_attn_out, context_lse = flash_attn_varlen_func(
|
||||
q=query_across_dcp,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
out=None,
|
||||
out=dcp_context_out,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
seqused_k=attn_metadata.dcp_context_kv_lens,
|
||||
@@ -894,11 +917,14 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
)
|
||||
context_lse_cor = context_lse_cor.transpose(0, 1).contiguous()
|
||||
|
||||
(dcp_query_out,) = current_workspace_manager().get_simultaneous(
|
||||
((query.shape[0], self.num_heads, self.head_size), self._dcp_dtype),
|
||||
)
|
||||
query_attn_out, query_lse = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
out=None,
|
||||
out=dcp_query_out,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
cu_seqlens_k=cu_seqlens_q,
|
||||
|
||||
Reference in New Issue
Block a user