diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py
index 8a08fba1e..57828924a 100644
--- a/vllm/v1/worker/gpu/attn_utils.py
+++ b/vllm/v1/worker/gpu/attn_utils.py
@@ -12,6 +12,7 @@ from vllm.v1.attention.backend import (
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
 )
+from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens
 from vllm.v1.kv_cache_interface import (
     AttentionSpec,
     KVCacheConfig,
@@ -143,6 +144,28 @@ def build_slot_mappings_by_layer(
     return slot_mappings_by_layer
 
 
+def prepare_dcp_local_seq_lens(
+    dcp_local_seq_lens: torch.Tensor,
+    seq_lens: torch.Tensor,
+    num_reqs: int,
+    dcp_size: int,
+    dcp_rank: int,
+    cp_kv_cache_interleave_size: int,
+) -> None:
+    """Populate the persistent DCP local seq_lens buffer (CUDA graph safe)."""
+    if dcp_size <= 1:
+        return
+
+    local_seq_lens = get_dcp_local_seq_lens(
+        seq_lens[:num_reqs],
+        dcp_size=dcp_size,
+        dcp_rank=dcp_rank,
+        cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
+    )
+    dcp_local_seq_lens[:num_reqs].copy_(local_seq_lens, non_blocking=True)
+    dcp_local_seq_lens[num_reqs:].zero_()
+
+
 def build_attn_metadata(
     attn_metadata_builders: list[AttentionMetadataBuilder],
     num_reqs: int,
@@ -155,9 +178,13 @@ def build_attn_metadata(
     block_tables: Sequence[torch.Tensor],
     slot_mappings: torch.Tensor,
     kv_cache_config: KVCacheConfig,
+    dcp_local_seq_lens: torch.Tensor | None = None,
 ) -> dict[str, Any]:
     seq_lens = seq_lens[:num_reqs]
 
+    if dcp_local_seq_lens is not None:
+        dcp_local_seq_lens = dcp_local_seq_lens[:num_reqs]
+
     attn_metadata: dict[str, Any] = {}
     kv_cache_groups = kv_cache_config.kv_cache_groups
     for i, kv_cache_spec in enumerate(kv_cache_groups):
@@ -175,6 +202,7 @@
             block_table_tensor=block_table,
             slot_mapping=slot_mapping,
             causal=True,
+            dcp_local_seq_lens=dcp_local_seq_lens,
         )
 
         attn_metadata_builder = attn_metadata_builders[i]
diff --git a/vllm/v1/worker/gpu/block_table.py b/vllm/v1/worker/gpu/block_table.py
index 3f54fa56e..a172bf225 100644
--- a/vllm/v1/worker/gpu/block_table.py
+++ b/vllm/v1/worker/gpu/block_table.py
@@ -4,6 +4,7 @@ from collections.abc import Iterable
 
 import torch
 
+from vllm.distributed import get_dcp_group
 from vllm.triton_utils import tl, triton
 from vllm.utils.math_utils import cdiv
 from vllm.v1.attention.backends.utils import PAD_SLOT_ID
@@ -18,19 +19,36 @@ class BlockTables:
         max_num_batched_tokens: int,
         max_model_len: int,
         device: torch.device,
+        cp_kv_cache_interleave_size: int = 1,
     ):
         self.block_sizes = block_sizes
         self.max_num_reqs = max_num_reqs
         self.max_num_batched_tokens = max_num_batched_tokens
         self.max_model_len = max_model_len
         self.device = device
+        assert cp_kv_cache_interleave_size >= 1
+        self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
+
+        try:
+            dcp = get_dcp_group()
+            self.dcp_world_size, self.dcp_rank = dcp.world_size, dcp.rank_in_group
+        except AssertionError:
+            self.dcp_world_size, self.dcp_rank = 1, 0
+        # TODO(wentao): PCP support
+        self.total_cp_world_size = self.dcp_world_size
+        self.total_cp_rank = self.dcp_rank
 
         self.num_kv_cache_groups = len(self.block_sizes)
         # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
         self.block_tables: list[StagedWriteTensor] = []
         for i in range(self.num_kv_cache_groups):
             block_size = self.block_sizes[i]
-            max_num_blocks = cdiv(self.max_model_len, block_size)
+            # with DCP, a request's KV is sharded across ranks, so one
+            # physical block on this rank corresponds to
+            # `block_size * total_cp_world_size` tokens in the global
+            # (unsharded) sequence.
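+            # For example (illustrative numbers): with block_size=16 and
+            # a total CP world size of 4, each block-table entry covers
+            # 16 * 4 = 64 consecutive token positions.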
+            virtual_block_size = block_size * self.total_cp_world_size
+            max_num_blocks = cdiv(self.max_model_len, virtual_block_size)
             block_table = StagedWriteTensor(
                 (self.max_num_reqs, max_num_blocks),
                 dtype=torch.int32,
@@ -131,6 +149,9 @@ class BlockTables:
             self.block_sizes_tensor,
             self.slot_mappings,
             self.slot_mappings.stride(0),
+            TOTAL_CP_WORLD_SIZE=self.total_cp_world_size,
+            TOTAL_CP_RANK=self.total_cp_rank,
+            CP_KV_CACHE_INTERLEAVE_SIZE=self.cp_kv_cache_interleave_size,
             PAD_ID=PAD_SLOT_ID,
             TRITON_BLOCK_SIZE=1024,  # type: ignore
         )
@@ -183,6 +204,9 @@ def _compute_slot_mappings_kernel(
     block_sizes,  # [num_kv_cache_groups]
     slot_mappings_ptr,  # [num_kv_cache_groups, max_num_tokens]
     slot_mappings_stride,
+    TOTAL_CP_WORLD_SIZE: tl.constexpr,
+    TOTAL_CP_RANK: tl.constexpr,
+    CP_KV_CACHE_INTERLEAVE_SIZE: tl.constexpr,
     PAD_ID: tl.constexpr,
     TRITON_BLOCK_SIZE: tl.constexpr,
 ):
@@ -201,6 +225,7 @@ def _compute_slot_mappings_kernel(
     block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
     block_table_stride = tl.load(block_table_strides + group_id)
     block_size = tl.load(block_sizes + group_id)
+    virtual_block_size = block_size * TOTAL_CP_WORLD_SIZE
 
     req_state_idx = tl.load(idx_mapping + batch_idx)
     start_idx = tl.load(query_start_loc + batch_idx)
@@ -208,11 +233,26 @@
     for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE):
         offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
         positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
-        block_indices = positions // block_size
+        block_indices = positions // virtual_block_size
         block_numbers = tl.load(
             block_table_ptr + req_state_idx * block_table_stride + block_indices
         )
-        slot_ids = block_numbers * block_size + positions % block_size
+        virtual_block_offsets = positions - block_indices * virtual_block_size
+
+        # determine whether the token is stored on this CP rank.
+        is_local = (
+            virtual_block_offsets // CP_KV_CACHE_INTERLEAVE_SIZE
+        ) % TOTAL_CP_WORLD_SIZE == TOTAL_CP_RANK
+        # map virtual block offsets to local block offsets.
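+        # e.g. (illustrative numbers) CP_KV_CACHE_INTERLEAVE_SIZE=2 and
+        # TOTAL_CP_WORLD_SIZE=4: virtual offsets 0-1 land on rank 0,
+        # 2-3 on rank 1, ..., 6-7 on rank 3, and 8-9 wrap back to
+        # rank 0 at local offsets 2-3.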
+        local_block_offsets = (
+            virtual_block_offsets // (TOTAL_CP_WORLD_SIZE * CP_KV_CACHE_INTERLEAVE_SIZE)
+        ) * CP_KV_CACHE_INTERLEAVE_SIZE + (
+            virtual_block_offsets % CP_KV_CACHE_INTERLEAVE_SIZE
+        )
+
+        # physical slot index
+        slot_ids = block_numbers * block_size + local_block_offsets
+        slot_ids = tl.where(is_local, slot_ids, PAD_ID)
         tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
 
 
diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py
index d5a22d6a0..41a45ac87 100644
--- a/vllm/v1/worker/gpu/cudagraph_utils.py
+++ b/vllm/v1/worker/gpu/cudagraph_utils.py
@@ -10,6 +10,7 @@ from tqdm import tqdm
 
 from vllm.config import VllmConfig
 from vllm.config.compilation import CUDAGraphMode
+from vllm.distributed import get_dcp_group
 from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
 from vllm.forward_context import set_forward_context
 from vllm.v1.attention.backend import AttentionMetadataBuilder
@@ -17,6 +18,7 @@ from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.worker.gpu.attn_utils import (
     build_attn_metadata,
     build_slot_mappings_by_layer,
+    prepare_dcp_local_seq_lens,
 )
 from vllm.v1.worker.gpu.block_table import BlockTables
 from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
@@ -257,6 +259,23 @@ def prepare_inputs_to_capture(
     input_buffers.seq_lens[:num_reqs] = num_tokens
     input_buffers.seq_lens[num_reqs:] = 0
 
+    try:
+        dcp_group = get_dcp_group()
+        dcp_world_size = dcp_group.world_size
+        dcp_rank = dcp_group.rank_in_group
+    except AssertionError:
+        dcp_world_size = 1
+        dcp_rank = 0
+    if dcp_world_size > 1:
+        prepare_dcp_local_seq_lens(
+            input_buffers.dcp_local_seq_lens,
+            input_buffers.seq_lens,
+            num_reqs,
+            dcp_size=dcp_world_size,
+            dcp_rank=dcp_rank,
+            cp_kv_cache_interleave_size=block_tables.cp_kv_cache_interleave_size,
+        )
+
     input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
     slot_mappings = block_tables.slot_mappings[:, :num_tokens]
     slot_mappings_by_layer = build_slot_mappings_by_layer(
@@ -275,5 +294,6 @@ def prepare_inputs_to_capture(
         block_tables=input_block_tables,
         slot_mappings=slot_mappings,
         kv_cache_config=kv_cache_config,
+        dcp_local_seq_lens=input_buffers.dcp_local_seq_lens,
     )
     return attn_metadata, slot_mappings_by_layer
diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py
index bdb67be11..a15da926d 100644
--- a/vllm/v1/worker/gpu/input_batch.py
+++ b/vllm/v1/worker/gpu/input_batch.py
@@ -27,6 +27,10 @@ class InputBuffers:
             max_num_reqs + 1, dtype=torch.int32, device=device
         )
         self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
+        # DCP: per-request local seq_lens buffer
+        self.dcp_local_seq_lens = torch.zeros(
+            max_num_reqs, dtype=torch.int32, device=device
+        )
 
 
 @dataclass
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 8cca3cb46..2c50ea15f 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -11,6 +11,7 @@ import torch.nn as nn
 from vllm.config import VllmConfig
 from vllm.config.compilation import CUDAGraphMode
 from vllm.distributed.parallel_state import (
+    get_dcp_group,
     get_pp_group,
     prepare_communication_buffer_for_model,
 )
@@ -24,6 +25,7 @@ from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
+from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
 from vllm.v1.worker.gpu.async_utils import AsyncOutput
 from vllm.v1.worker.gpu.attn_utils import (
     build_attn_metadata,
@@ -31,6 +33,7 @@ from vllm.v1.worker.gpu.attn_utils import (
     get_kv_cache_spec,
     init_attn_backend,
     init_kv_cache,
+    prepare_dcp_local_seq_lens,
 )
 from vllm.v1.worker.gpu.block_table import BlockTables
 from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
@@ -248,11 +251,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             max_num_batched_tokens=self.max_num_tokens,
             max_model_len=self.max_model_len,
             device=self.device,
+            cp_kv_cache_interleave_size=(
+                self.parallel_config.cp_kv_cache_interleave_size
+            ),
         )
 
         self.attn_backends, self.attn_metadata_builders = init_attn_backend(
             self.kv_cache_config, self.vllm_config, self.device
         )
+        check_attention_cp_compatibility(self.vllm_config)
         if self.do_spec_decode:
             # HACK(woosuk)
             self.speculator.set_attn(
@@ -294,6 +301,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             block_tables=block_tables,
             slot_mappings=slot_mappings,
             kv_cache_config=self.kv_cache_config,
+            dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens,
         )
         input_batch.attn_metadata = attn_metadata
         input_batch.slot_mappings = slot_mappings_by_layer
@@ -627,6 +635,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         )
         seq_lens = self.input_buffers.seq_lens[:num_reqs]
 
+        dcp_size = self.parallel_config.decode_context_parallel_size
+        if dcp_size > 1:
+            prepare_dcp_local_seq_lens(
+                self.input_buffers.dcp_local_seq_lens,
+                seq_lens,
+                num_reqs,
+                dcp_size=dcp_size,
+                dcp_rank=get_dcp_group().rank_in_group,
+                cp_kv_cache_interleave_size=(
+                    self.parallel_config.cp_kv_cache_interleave_size
+                ),
+            )
+
         # Prepare M-RoPE positions.
         if self.uses_mrope:
             self.mrope_states.prepare_mrope_positions(
@@ -674,6 +695,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             block_tables=block_tables,
             slot_mappings=slot_mappings,
             kv_cache_config=self.kv_cache_config,
+            dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens,
        )
 
         input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
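
A minimal pure-Python sketch of the slot-mapping math in _compute_slot_mappings_kernel above, useful for sanity-checking the interleaved DCP layout. The helper name local_slot_id, its parameters, and the numbers in the assertion are illustrative only, not part of the patch:

def local_slot_id(
    position: int,
    block_size: int,
    cp_world_size: int,
    cp_rank: int,
    interleave_size: int,
    block_table: list[int],
    pad_id: int = -1,
) -> int:
    # A "virtual" block spans block_size * cp_world_size global positions.
    virtual_block_size = block_size * cp_world_size
    block_idx = position // virtual_block_size
    virtual_off = position % virtual_block_size
    # Positions are dealt to CP ranks in runs of `interleave_size` tokens.
    if (virtual_off // interleave_size) % cp_world_size != cp_rank:
        return pad_id  # this rank does not store the token's KV
    # Compact the positions owned by this rank into contiguous local offsets.
    local_off = (
        virtual_off // (cp_world_size * interleave_size)
    ) * interleave_size + virtual_off % interleave_size
    return block_table[block_idx] * block_size + local_off

# block_size=4, cp_world_size=2, cp_rank=0, interleave_size=1, block 7:
# even positions live on rank 0 at slots 28 and 29; odd ones are padded out.
assert [local_slot_id(p, 4, 2, 0, 1, [7]) for p in range(4)] == [28, -1, 29, -1]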