[Bugfix][DCP] Fix CUDA graph capture for Decode Context Parallelism (#36070)

Signed-off-by: Sungsoo Ha <sungsooh@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
sungsoo ha
2026-03-30 17:20:43 -07:00
committed by GitHub
parent bb51d5b40d
commit 4ac227222f

View File

@@ -22,9 +22,11 @@ from vllm.v1.attention.backends.fa_utils import (
get_flash_attn_version,
is_flash_attn_varlen_func_available,
)
from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from vllm.v1.worker.workspace import current_workspace_manager
if is_flash_attn_varlen_func_available():
from vllm.v1.attention.backends.fa_utils import (
@@ -52,7 +54,6 @@ from vllm.v1.attention.backend import (
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
get_dcp_local_seq_lens,
get_kv_cache_layout,
)
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -356,6 +357,14 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.attention_config.flash_attn_max_num_splits_for_cuda_graph
)
if self.dcp_world_size > 1:
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
self._dcp_context_kv_lens = torch.zeros(
max_num_reqs,
dtype=torch.int32,
device=self.device,
)
# Sliding window size to be used with the AOT scheduler will be
# populated on first build() call.
self.aot_sliding_window: tuple[int, int] | None = None
@@ -452,15 +461,18 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
prefix_scheduler_metadata = None
if self.dcp_world_size > 1:
query_kv_lens = query_start_loc[1:] - query_start_loc[:-1]
dcp_context_kv_lens = seq_lens - query_kv_lens
dcp_context_kv_lens = get_dcp_local_seq_lens(
dcp_context_kv_lens,
query_lens = query_start_loc[1:] - query_start_loc[:-1]
context_kv_lens = seq_lens - query_lens
local_context_kv_lens = get_dcp_local_seq_lens(
context_kv_lens,
self.dcp_world_size,
self.dcp_rank,
self.cp_kv_cache_interleave_size,
)
self._dcp_context_kv_lens[:num_reqs] = local_context_kv_lens
self._dcp_context_kv_lens[num_reqs:] = 0
dcp_context_kv_lens = self._dcp_context_kv_lens[:num_reqs]
# After DCP distribution, the maximum number of tokens for any rank is
# ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size,
# and I is cp_kv_cache_interleave_size.
@@ -637,6 +649,10 @@ class FlashAttentionImpl(AttentionImpl):
)
self.dcp_combine = dcp_a2a_lse_reduce if dcp_a2a else cp_lse_ag_out_rs
self._dcp_dtype: torch.dtype | None = None
if vllm_config is not None and self.dcp_world_size > 1:
self._dcp_dtype = vllm_config.model_config.dtype
def forward(
self,
layer: torch.nn.Module,
@@ -862,11 +878,18 @@ class FlashAttentionImpl(AttentionImpl):
sliding_window_size = (
list(self.sliding_window) if self.sliding_window is not None else None
)
n = query_across_dcp.shape[0]
(dcp_context_out,) = current_workspace_manager().get_simultaneous(
(
(n, self.num_heads * self.dcp_world_size, self.head_size),
self._dcp_dtype,
),
)
context_attn_out, context_lse = flash_attn_varlen_func(
q=query_across_dcp,
k=key_cache,
v=value_cache,
out=None,
out=dcp_context_out,
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=attn_metadata.dcp_context_kv_lens,
@@ -894,11 +917,14 @@ class FlashAttentionImpl(AttentionImpl):
)
context_lse_cor = context_lse_cor.transpose(0, 1).contiguous()
(dcp_query_out,) = current_workspace_manager().get_simultaneous(
((query.shape[0], self.num_heads, self.head_size), self._dcp_dtype),
)
query_attn_out, query_lse = flash_attn_varlen_func(
q=query,
k=key,
v=value,
out=None,
out=dcp_query_out,
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
cu_seqlens_k=cu_seqlens_q,