diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index 57828924a..468e77113 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -12,7 +12,6 @@ from vllm.v1.attention.backend import ( AttentionMetadataBuilder, CommonAttentionMetadata, ) -from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens from vllm.v1.kv_cache_interface import ( AttentionSpec, KVCacheConfig, @@ -144,28 +143,6 @@ def build_slot_mappings_by_layer( return slot_mappings_by_layer -def prepare_dcp_local_seq_lens( - dcp_local_seq_lens: torch.Tensor, - seq_lens: torch.Tensor, - num_reqs: int, - dcp_size: int, - dcp_rank: int, - cp_kv_cache_interleave_size: int, -) -> None: - """Populate the persistent DCP local seq_lens buffer (CUDA graph safe).""" - if dcp_size <= 1: - return - - local_seq_lens = get_dcp_local_seq_lens( - seq_lens[:num_reqs], - dcp_size=dcp_size, - dcp_rank=dcp_rank, - cp_kv_cache_interleave_size=cp_kv_cache_interleave_size, - ) - dcp_local_seq_lens[:num_reqs].copy_(local_seq_lens, non_blocking=True) - dcp_local_seq_lens[num_reqs:].zero_() - - def build_attn_metadata( attn_metadata_builders: list[AttentionMetadataBuilder], num_reqs: int, @@ -181,7 +158,6 @@ def build_attn_metadata( dcp_local_seq_lens: torch.Tensor | None = None, ) -> dict[str, Any]: seq_lens = seq_lens[:num_reqs] - if dcp_local_seq_lens is not None: dcp_local_seq_lens = dcp_local_seq_lens[:num_reqs] diff --git a/vllm/v1/worker/gpu/block_table.py b/vllm/v1/worker/gpu/block_table.py index a172bf225..9dfdf834d 100644 --- a/vllm/v1/worker/gpu/block_table.py +++ b/vllm/v1/worker/gpu/block_table.py @@ -4,7 +4,6 @@ from collections.abc import Iterable import torch -from vllm.distributed import get_dcp_group from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.utils import PAD_SLOT_ID @@ -19,36 +18,29 @@ class BlockTables: max_num_batched_tokens: int, max_model_len: int, device: torch.device, - cp_kv_cache_interleave_size: int = 1, + cp_size: int = 1, + cp_rank: int = 0, + cp_interleave: int = 1, ): self.block_sizes = block_sizes self.max_num_reqs = max_num_reqs self.max_num_batched_tokens = max_num_batched_tokens self.max_model_len = max_model_len self.device = device - assert cp_kv_cache_interleave_size >= 1 - self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size - try: - dcp = get_dcp_group() - self.dcp_world_size, self.dcp_rank = dcp.world_size, dcp.rank_in_group - except AssertionError: - self.dcp_world_size, self.dcp_rank = 1, 0 - # TODO(wentao): PCP supprot - self.total_cp_world_size = self.dcp_world_size - self.total_cp_rank = self.dcp_rank + self.cp_size = cp_size + self.cp_rank = cp_rank + self.cp_interleave = cp_interleave self.num_kv_cache_groups = len(self.block_sizes) # num_kv_cache_groups x [max_num_reqs, max_num_blocks] self.block_tables: list[StagedWriteTensor] = [] for i in range(self.num_kv_cache_groups): block_size = self.block_sizes[i] - # with DCP, a request's KV is sharded across - # ranks, so one physical block on this rank - # corresponds to `block_size * total_cp_world_size` - # tokens in the global (unsharded) sequence. - virtual_block_size = block_size * self.total_cp_world_size - max_num_blocks = cdiv(self.max_model_len, virtual_block_size) + # When using DCP, each request's KV cache is sharded among different ranks. + # As a result, one block on the current rank covers `block_size * cp_size` + # tokens in the full, global (unsharded) sequence. + max_num_blocks = cdiv(self.max_model_len, block_size * self.cp_size) block_table = StagedWriteTensor( (self.max_num_reqs, max_num_blocks), dtype=torch.int32, @@ -149,9 +141,9 @@ class BlockTables: self.block_sizes_tensor, self.slot_mappings, self.slot_mappings.stride(0), - TOTAL_CP_WORLD_SIZE=self.total_cp_world_size, - TOTAL_CP_RANK=self.total_cp_rank, - CP_KV_CACHE_INTERLEAVE_SIZE=self.cp_kv_cache_interleave_size, + self.cp_rank, + CP_SIZE=self.cp_size, + CP_INTERLEAVE=self.cp_interleave, PAD_ID=PAD_SLOT_ID, TRITON_BLOCK_SIZE=1024, # type: ignore ) @@ -204,9 +196,9 @@ def _compute_slot_mappings_kernel( block_sizes, # [num_kv_cache_groups] slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens] slot_mappings_stride, - TOTAL_CP_WORLD_SIZE: tl.constexpr, - TOTAL_CP_RANK: tl.constexpr, - CP_KV_CACHE_INTERLEAVE_SIZE: tl.constexpr, + cp_rank, + CP_SIZE: tl.constexpr, + CP_INTERLEAVE: tl.constexpr, PAD_ID: tl.constexpr, TRITON_BLOCK_SIZE: tl.constexpr, ): @@ -225,7 +217,6 @@ def _compute_slot_mappings_kernel( block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32) block_table_stride = tl.load(block_table_strides + group_id) block_size = tl.load(block_sizes + group_id) - virtual_block_size = block_size * TOTAL_CP_WORLD_SIZE req_state_idx = tl.load(idx_mapping + batch_idx) start_idx = tl.load(query_start_loc + batch_idx) @@ -233,26 +224,25 @@ def _compute_slot_mappings_kernel( for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE): offset = i + tl.arange(0, TRITON_BLOCK_SIZE) positions = tl.load(pos + offset, mask=offset < end_idx, other=0) - block_indices = positions // virtual_block_size + + block_indices = positions // (block_size * CP_SIZE) + block_offsets = positions % (block_size * CP_SIZE) block_numbers = tl.load( block_table_ptr + req_state_idx * block_table_stride + block_indices ) - virtual_block_offsets = positions - block_indices * virtual_block_size - # determine whether the token is stored on this CP rank. - is_local = ( - virtual_block_offsets // CP_KV_CACHE_INTERLEAVE_SIZE - ) % TOTAL_CP_WORLD_SIZE == TOTAL_CP_RANK - # mapping virture block offsets to local block offsets. - local_block_offsets = ( - virtual_block_offsets // (TOTAL_CP_WORLD_SIZE * CP_KV_CACHE_INTERLEAVE_SIZE) - ) * CP_KV_CACHE_INTERLEAVE_SIZE + ( - virtual_block_offsets % CP_KV_CACHE_INTERLEAVE_SIZE - ) + if CP_SIZE == 1: + # Common case: Context parallelism is not used. + slot_ids = block_numbers * block_size + block_offsets + else: + # Context parallelism is used. + is_local = block_offsets // CP_INTERLEAVE % CP_SIZE == cp_rank + rounds = block_offsets // (CP_INTERLEAVE * CP_SIZE) + remainder = block_offsets % CP_INTERLEAVE + local_offsets = rounds * CP_INTERLEAVE + remainder + slot_ids = block_numbers * block_size + local_offsets + slot_ids = tl.where(is_local, slot_ids, PAD_ID) - # physical slot index - slot_ids = block_numbers * block_size + local_block_offsets - slot_ids = tl.where(is_local, slot_ids, PAD_ID) tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx) diff --git a/vllm/v1/worker/gpu/cp_utils.py b/vllm/v1/worker/gpu/cp_utils.py new file mode 100644 index 000000000..6dd8fd347 --- /dev/null +++ b/vllm/v1/worker/gpu/cp_utils.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.triton_utils import tl, triton + + +def prepare_dcp_local_seq_lens( + dcp_local_seq_lens: torch.Tensor, + seq_lens: torch.Tensor, + num_reqs: int, + dcp_size: int, + dcp_rank: int, + cp_interleave: int, +) -> None: + """Populate the persistent DCP local seq_lens buffer (CUDA graph safe).""" + if dcp_size == 1: + return + + max_num_reqs = dcp_local_seq_lens.shape[0] + BLOCK_SIZE = 128 + num_blocks = triton.cdiv(max_num_reqs, BLOCK_SIZE) + _dcp_local_seq_lens_kernel[(num_blocks,)]( + dcp_local_seq_lens, + seq_lens, + dcp_size, + dcp_rank, + cp_interleave, + num_reqs, + max_num_reqs, + BLOCK_SIZE, + ) + + +@triton.jit +def _dcp_local_seq_lens_kernel( + out_ptr, + seq_lens_ptr, + dcp_size, + dcp_rank, + cp_interleave, + num_reqs, + max_num_reqs, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + block = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + seq_lens = tl.load(seq_lens_ptr + block, mask=block < num_reqs) + + # Distribute KV cache among different ranks, in a round-robin manner. + rounds = seq_lens // (dcp_size * cp_interleave) + remainder = seq_lens % (dcp_size * cp_interleave) + + remainder = tl.maximum(remainder - dcp_rank * cp_interleave, 0) + remainder = tl.minimum(remainder, cp_interleave) + local_seq_lens = rounds * cp_interleave + remainder + + # For [num_reqs, max_num_reqs), pad with 0 + local_seq_lens = tl.where(block < num_reqs, local_seq_lens, 0) + tl.store(out_ptr + block, local_seq_lens, mask=block < max_num_reqs) diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py index 41a45ac87..0c5a93abc 100644 --- a/vllm/v1/worker/gpu/cudagraph_utils.py +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -10,7 +10,6 @@ from tqdm import tqdm from vllm.config import VllmConfig from vllm.config.compilation import CUDAGraphMode -from vllm.distributed import get_dcp_group from vllm.distributed.parallel_state import graph_capture, is_global_first_rank from vllm.forward_context import set_forward_context from vllm.v1.attention.backend import AttentionMetadataBuilder @@ -18,7 +17,6 @@ from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.worker.gpu.attn_utils import ( build_attn_metadata, build_slot_mappings_by_layer, - prepare_dcp_local_seq_lens, ) from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp @@ -259,22 +257,8 @@ def prepare_inputs_to_capture( input_buffers.seq_lens[:num_reqs] = num_tokens input_buffers.seq_lens[num_reqs:] = 0 - try: - dcp_group = get_dcp_group() - dcp_world_size = dcp_group.world_size - dcp_rank = dcp_group.rank_in_group - except AssertionError: - dcp_world_size = 1 - dcp_rank = 0 - if dcp_world_size > 1: - prepare_dcp_local_seq_lens( - input_buffers.dcp_local_seq_lens, - input_buffers.seq_lens, - num_reqs, - dcp_size=dcp_world_size, - dcp_rank=dcp_rank, - cp_kv_cache_interleave_size=block_tables.cp_kv_cache_interleave_size, - ) + input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens + input_buffers.dcp_local_seq_lens[num_reqs:] = 0 input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables] slot_mappings = block_tables.slot_mappings[:, :num_tokens] diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index e9f9d868f..be620b0cc 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -33,10 +33,10 @@ from vllm.v1.worker.gpu.attn_utils import ( get_kv_cache_spec, init_attn_backend, init_kv_cache, - prepare_dcp_local_seq_lens, ) from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu +from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager from vllm.v1.worker.gpu.dp_utils import ( get_cudagraph_and_dp_padding, @@ -192,6 +192,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.is_first_pp_rank = True self.is_last_pp_rank = True + # Decode context parallelism. + self.dcp_size = self.parallel_config.decode_context_parallel_size + self.use_dcp = self.dcp_size > 1 + self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0 + self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size + def update_max_model_len(self, max_model_len: int) -> None: self.max_model_len = max_model_len self.req_states.max_model_len = max_model_len @@ -251,9 +257,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): max_num_batched_tokens=self.max_num_tokens, max_model_len=self.max_model_len, device=self.device, - cp_kv_cache_interleave_size=( - self.parallel_config.cp_kv_cache_interleave_size - ), + cp_size=self.dcp_size, + cp_rank=self.dcp_rank, + cp_interleave=self.cp_interleave, ) self.attn_backends, self.attn_metadata_builders = init_attn_backend( @@ -636,18 +642,17 @@ class GPUModelRunner(LoRAModelRunnerMixin): ) seq_lens = self.input_buffers.seq_lens[:num_reqs] - dcp_size = self.parallel_config.decode_context_parallel_size - if dcp_size > 1: + if self.use_dcp: + # Prepare dcp local seq_lens. prepare_dcp_local_seq_lens( self.input_buffers.dcp_local_seq_lens, - seq_lens, + self.input_buffers.seq_lens, num_reqs, - dcp_size=dcp_size, - dcp_rank=get_dcp_group().rank_in_group, - cp_kv_cache_interleave_size=( - self.parallel_config.cp_kv_cache_interleave_size - ), + self.dcp_size, + self.dcp_rank, + self.cp_interleave, ) + dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs] # Prepare M-RoPE positions. if self.uses_mrope: @@ -696,7 +701,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): block_tables=block_tables, slot_mappings=slot_mappings, kv_cache_config=self.kv_cache_config, - dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens, + dcp_local_seq_lens=dcp_local_seq_lens, ) input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]