[Model Runner V2] Minor simplification for DCP (#34786)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Woosuk Kwon authored on 2026-02-18 11:04:53 -08:00 (committed by GitHub)
parent 0e60c925cf
commit 95be2a7f22
5 changed files with 111 additions and 95 deletions

View File: vllm/v1/worker/gpu/attn_utils.py

@@ -12,7 +12,6 @@ from vllm.v1.attention.backend import (
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
 )
-from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens
 from vllm.v1.kv_cache_interface import (
     AttentionSpec,
     KVCacheConfig,
@@ -144,28 +143,6 @@ def build_slot_mappings_by_layer(
     return slot_mappings_by_layer
-def prepare_dcp_local_seq_lens(
-    dcp_local_seq_lens: torch.Tensor,
-    seq_lens: torch.Tensor,
-    num_reqs: int,
-    dcp_size: int,
-    dcp_rank: int,
-    cp_kv_cache_interleave_size: int,
-) -> None:
-    """Populate the persistent DCP local seq_lens buffer (CUDA graph safe)."""
-    if dcp_size <= 1:
-        return
-    local_seq_lens = get_dcp_local_seq_lens(
-        seq_lens[:num_reqs],
-        dcp_size=dcp_size,
-        dcp_rank=dcp_rank,
-        cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
-    )
-    dcp_local_seq_lens[:num_reqs].copy_(local_seq_lens, non_blocking=True)
-    dcp_local_seq_lens[num_reqs:].zero_()
 def build_attn_metadata(
     attn_metadata_builders: list[AttentionMetadataBuilder],
     num_reqs: int,
@@ -181,7 +158,6 @@ def build_attn_metadata(
dcp_local_seq_lens: torch.Tensor | None = None,
) -> dict[str, Any]:
seq_lens = seq_lens[:num_reqs]
if dcp_local_seq_lens is not None:
dcp_local_seq_lens = dcp_local_seq_lens[:num_reqs]
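For context, here is a minimal CPU-only sketch (illustrative values, plain torch) of the buffer contract the removed helper maintained, and which its Triton replacement in cp_utils.py preserves: entries [0, num_reqs) hold this rank's local sequence lengths, and the tail [num_reqs, max_num_reqs) is zeroed so the full-size persistent buffer stays valid for CUDA graph capture.

    import torch

    max_num_reqs, num_reqs = 8, 3
    dcp_local_seq_lens = torch.empty(max_num_reqs, dtype=torch.int32)
    local_seq_lens = torch.tensor([5, 7, 2], dtype=torch.int32)  # example lengths
    dcp_local_seq_lens[:num_reqs].copy_(local_seq_lens)  # populate the head
    dcp_local_seq_lens[num_reqs:].zero_()  # zero the padded tail
    print(dcp_local_seq_lens)  # tensor([5, 7, 2, 0, 0, 0, 0, 0], dtype=torch.int32)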

View File: vllm/v1/worker/gpu/block_table.py

@@ -4,7 +4,6 @@ from collections.abc import Iterable
 import torch
-from vllm.distributed import get_dcp_group
 from vllm.triton_utils import tl, triton
 from vllm.utils.math_utils import cdiv
 from vllm.v1.attention.backends.utils import PAD_SLOT_ID
@@ -19,36 +18,29 @@ class BlockTables:
         max_num_batched_tokens: int,
         max_model_len: int,
         device: torch.device,
-        cp_kv_cache_interleave_size: int = 1,
+        cp_size: int = 1,
+        cp_rank: int = 0,
+        cp_interleave: int = 1,
     ):
         self.block_sizes = block_sizes
         self.max_num_reqs = max_num_reqs
         self.max_num_batched_tokens = max_num_batched_tokens
         self.max_model_len = max_model_len
         self.device = device
-        assert cp_kv_cache_interleave_size >= 1
-        self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
-        try:
-            dcp = get_dcp_group()
-            self.dcp_world_size, self.dcp_rank = dcp.world_size, dcp.rank_in_group
-        except AssertionError:
-            self.dcp_world_size, self.dcp_rank = 1, 0
-        # TODO(wentao): PCP support
-        self.total_cp_world_size = self.dcp_world_size
-        self.total_cp_rank = self.dcp_rank
+        self.cp_size = cp_size
+        self.cp_rank = cp_rank
+        self.cp_interleave = cp_interleave
         self.num_kv_cache_groups = len(self.block_sizes)
         # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
         self.block_tables: list[StagedWriteTensor] = []
         for i in range(self.num_kv_cache_groups):
             block_size = self.block_sizes[i]
-            # With DCP, a request's KV is sharded across
-            # ranks, so one physical block on this rank
-            # corresponds to `block_size * total_cp_world_size`
-            # tokens in the global (unsharded) sequence.
-            virtual_block_size = block_size * self.total_cp_world_size
-            max_num_blocks = cdiv(self.max_model_len, virtual_block_size)
+            # When using DCP, each request's KV cache is sharded among different ranks.
+            # As a result, one block on the current rank covers `block_size * cp_size`
+            # tokens in the full, global (unsharded) sequence.
+            max_num_blocks = cdiv(self.max_model_len, block_size * self.cp_size)
             block_table = StagedWriteTensor(
                 (self.max_num_reqs, max_num_blocks),
                 dtype=torch.int32,
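A worked example (illustrative numbers, not from this diff) of the sizing above: since each of the cp_size ranks stores an interleaved shard of the KV cache, one physical block of block_size slots on this rank covers block_size * cp_size tokens of the global sequence, so the per-request block table shrinks proportionally.

    from math import ceil  # stands in for vllm's cdiv

    max_model_len, block_size = 8192, 16
    for cp_size in (1, 2, 4):
        max_num_blocks = ceil(max_model_len / (block_size * cp_size))
        print(cp_size, max_num_blocks)  # 1 -> 512, 2 -> 256, 4 -> 128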
@@ -149,9 +141,9 @@ class BlockTables:
             self.block_sizes_tensor,
             self.slot_mappings,
             self.slot_mappings.stride(0),
-            TOTAL_CP_WORLD_SIZE=self.total_cp_world_size,
-            TOTAL_CP_RANK=self.total_cp_rank,
-            CP_KV_CACHE_INTERLEAVE_SIZE=self.cp_kv_cache_interleave_size,
+            self.cp_rank,
+            CP_SIZE=self.cp_size,
+            CP_INTERLEAVE=self.cp_interleave,
             PAD_ID=PAD_SLOT_ID,
             TRITON_BLOCK_SIZE=1024,  # type: ignore
         )
@@ -204,9 +196,9 @@ def _compute_slot_mappings_kernel(
     block_sizes,  # [num_kv_cache_groups]
     slot_mappings_ptr,  # [num_kv_cache_groups, max_num_tokens]
     slot_mappings_stride,
-    TOTAL_CP_WORLD_SIZE: tl.constexpr,
-    TOTAL_CP_RANK: tl.constexpr,
-    CP_KV_CACHE_INTERLEAVE_SIZE: tl.constexpr,
+    cp_rank,
+    CP_SIZE: tl.constexpr,
+    CP_INTERLEAVE: tl.constexpr,
     PAD_ID: tl.constexpr,
     TRITON_BLOCK_SIZE: tl.constexpr,
 ):
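One detail in the new signature: cp_rank is an ordinary runtime argument, while CP_SIZE and CP_INTERLEAVE remain tl.constexpr. Presumably this lets every rank share the same compiled kernel (constexpr values are baked into the binary and would otherwise produce one variant per rank), while the CP_SIZE == 1 branch below can still be compiled away in the common case. A toy illustration of the distinction (hypothetical kernel, requires triton and a CUDA device):

    import torch
    from vllm.triton_utils import tl, triton

    @triton.jit
    def _toy_kernel(out_ptr, rank, SIZE: tl.constexpr):
        # SIZE specializes the compiled binary; rank is a plain argument,
        # so all ranks reuse one compiled variant per SIZE value.
        if SIZE == 1:
            tl.store(out_ptr, 0)
        else:
            tl.store(out_ptr, rank)

    out = torch.empty(1, dtype=torch.int32, device="cuda")
    _toy_kernel[(1,)](out, 3, SIZE=4)
    print(out)  # tensor([3], device='cuda:0', dtype=torch.int32)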
@@ -225,7 +217,6 @@ def _compute_slot_mappings_kernel(
     block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
     block_table_stride = tl.load(block_table_strides + group_id)
     block_size = tl.load(block_sizes + group_id)
-    virtual_block_size = block_size * TOTAL_CP_WORLD_SIZE
     req_state_idx = tl.load(idx_mapping + batch_idx)
     start_idx = tl.load(query_start_loc + batch_idx)
@@ -233,26 +224,25 @@
     for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE):
         offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
         positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
-        block_indices = positions // virtual_block_size
+        block_indices = positions // (block_size * CP_SIZE)
+        block_offsets = positions % (block_size * CP_SIZE)
         block_numbers = tl.load(
             block_table_ptr + req_state_idx * block_table_stride + block_indices
         )
-        virtual_block_offsets = positions - block_indices * virtual_block_size
-        # Determine whether the token is stored on this CP rank.
-        is_local = (
-            virtual_block_offsets // CP_KV_CACHE_INTERLEAVE_SIZE
-        ) % TOTAL_CP_WORLD_SIZE == TOTAL_CP_RANK
-        # Map virtual block offsets to local block offsets.
-        local_block_offsets = (
-            virtual_block_offsets // (TOTAL_CP_WORLD_SIZE * CP_KV_CACHE_INTERLEAVE_SIZE)
-        ) * CP_KV_CACHE_INTERLEAVE_SIZE + (
-            virtual_block_offsets % CP_KV_CACHE_INTERLEAVE_SIZE
-        )
+        if CP_SIZE == 1:
+            # Common case: Context parallelism is not used.
+            slot_ids = block_numbers * block_size + block_offsets
+        else:
+            # Context parallelism is used.
+            is_local = block_offsets // CP_INTERLEAVE % CP_SIZE == cp_rank
+            rounds = block_offsets // (CP_INTERLEAVE * CP_SIZE)
+            remainder = block_offsets % CP_INTERLEAVE
+            local_offsets = rounds * CP_INTERLEAVE + remainder
+            slot_ids = block_numbers * block_size + local_offsets
+            slot_ids = tl.where(is_local, slot_ids, PAD_ID)
-        # Physical slot index.
-        slot_ids = block_numbers * block_size + local_block_offsets
-        slot_ids = tl.where(is_local, slot_ids, PAD_ID)
         tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
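The interleaved round-robin layout can be hard to follow in kernel form; this pure-Python reference (illustrative, not the Triton kernel itself) computes the same slot ids, with PAD_ID standing in for PAD_SLOT_ID:

    PAD_ID = -1

    def slot_id(pos, block_table, block_size, cp_size, cp_rank, cp_interleave):
        block_idx = pos // (block_size * cp_size)
        block_off = pos % (block_size * cp_size)
        block_num = block_table[block_idx]
        if cp_size == 1:
            # Common case: no context parallelism.
            return block_num * block_size + block_off
        # The token lives on this rank only if its interleave chunk maps here.
        is_local = (block_off // cp_interleave) % cp_size == cp_rank
        rounds = block_off // (cp_interleave * cp_size)
        local_off = rounds * cp_interleave + block_off % cp_interleave
        return block_num * block_size + local_off if is_local else PAD_ID

    # block_size=4, cp_size=2, cp_interleave=1: even positions land on rank 0.
    table = [7, 3]  # physical block numbers for one request
    print([slot_id(p, table, 4, 2, 0, 1) for p in range(8)])
    # [28, -1, 29, -1, 30, -1, 31, -1]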

View File: vllm/v1/worker/gpu/cp_utils.py

@@ -0,0 +1,61 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+
+from vllm.triton_utils import tl, triton
+
+
+def prepare_dcp_local_seq_lens(
+    dcp_local_seq_lens: torch.Tensor,
+    seq_lens: torch.Tensor,
+    num_reqs: int,
+    dcp_size: int,
+    dcp_rank: int,
+    cp_interleave: int,
+) -> None:
+    """Populate the persistent DCP local seq_lens buffer (CUDA graph safe)."""
+    if dcp_size == 1:
+        return
+    max_num_reqs = dcp_local_seq_lens.shape[0]
+    BLOCK_SIZE = 128
+    num_blocks = triton.cdiv(max_num_reqs, BLOCK_SIZE)
+    _dcp_local_seq_lens_kernel[(num_blocks,)](
+        dcp_local_seq_lens,
+        seq_lens,
+        dcp_size,
+        dcp_rank,
+        cp_interleave,
+        num_reqs,
+        max_num_reqs,
+        BLOCK_SIZE,
+    )
+
+
+@triton.jit
+def _dcp_local_seq_lens_kernel(
+    out_ptr,
+    seq_lens_ptr,
+    dcp_size,
+    dcp_rank,
+    cp_interleave,
+    num_reqs,
+    max_num_reqs,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    block = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    seq_lens = tl.load(seq_lens_ptr + block, mask=block < num_reqs)
+    # Distribute KV cache among different ranks, in a round-robin manner.
+    rounds = seq_lens // (dcp_size * cp_interleave)
+    remainder = seq_lens % (dcp_size * cp_interleave)
+    remainder = tl.maximum(remainder - dcp_rank * cp_interleave, 0)
+    remainder = tl.minimum(remainder, cp_interleave)
+    local_seq_lens = rounds * cp_interleave + remainder
+    # For indices in [num_reqs, max_num_reqs), pad with 0.
+    local_seq_lens = tl.where(block < num_reqs, local_seq_lens, 0)
+    tl.store(out_ptr + block, local_seq_lens, mask=block < max_num_reqs)
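For readers unfamiliar with the round-robin split, this pure-Python reference (illustrative) mirrors the kernel's arithmetic: of every dcp_size * cp_interleave consecutive tokens, this rank stores the cp_interleave-sized chunk at offset dcp_rank * cp_interleave.

    def local_seq_len(seq_len, dcp_size, dcp_rank, cp_interleave):
        rounds = seq_len // (dcp_size * cp_interleave)
        remainder = seq_len % (dcp_size * cp_interleave)
        remainder = max(remainder - dcp_rank * cp_interleave, 0)
        return rounds * cp_interleave + min(remainder, cp_interleave)

    # seq_len=10, dcp_size=4, cp_interleave=1: ranks hold 3, 3, 2, 2 tokens.
    print([local_seq_len(10, 4, r, 1) for r in range(4)])  # [3, 3, 2, 2]
    assert sum(local_seq_len(10, 4, r, 1) for r in range(4)) == 10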

View File

@@ -10,7 +10,6 @@ from tqdm import tqdm
 from vllm.config import VllmConfig
 from vllm.config.compilation import CUDAGraphMode
-from vllm.distributed import get_dcp_group
 from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
 from vllm.forward_context import set_forward_context
 from vllm.v1.attention.backend import AttentionMetadataBuilder
@@ -18,7 +17,6 @@ from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.worker.gpu.attn_utils import (
     build_attn_metadata,
     build_slot_mappings_by_layer,
-    prepare_dcp_local_seq_lens,
 )
 from vllm.v1.worker.gpu.block_table import BlockTables
 from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
@@ -259,22 +257,8 @@ def prepare_inputs_to_capture(
input_buffers.seq_lens[:num_reqs] = num_tokens
input_buffers.seq_lens[num_reqs:] = 0
try:
dcp_group = get_dcp_group()
dcp_world_size = dcp_group.world_size
dcp_rank = dcp_group.rank_in_group
except AssertionError:
dcp_world_size = 1
dcp_rank = 0
if dcp_world_size > 1:
prepare_dcp_local_seq_lens(
input_buffers.dcp_local_seq_lens,
input_buffers.seq_lens,
num_reqs,
dcp_size=dcp_world_size,
dcp_rank=dcp_rank,
cp_kv_cache_interleave_size=block_tables.cp_kv_cache_interleave_size,
)
input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
input_buffers.dcp_local_seq_lens[num_reqs:] = 0
input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
slot_mappings = block_tables.slot_mappings[:, :num_tokens]
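This replacement works, presumably, because CUDA graph capture bakes in buffer addresses and shapes rather than values: the real local sequence lengths are written into the same persistent buffer before each replay, so the capture path can fill it with a dummy value (num_tokens) instead of running the real round-robin computation. A CPU sketch (illustrative values) of the capture-time fill:

    import torch

    max_num_reqs, num_reqs, num_tokens = 8, 4, 1
    dcp_local_seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32)  # persistent buffer
    dcp_local_seq_lens[:num_reqs] = num_tokens  # dummy values suffice at capture time
    dcp_local_seq_lens[num_reqs:] = 0
    print(dcp_local_seq_lens)  # tensor([1, 1, 1, 1, 0, 0, 0, 0], dtype=torch.int32)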

View File

@@ -33,10 +33,10 @@ from vllm.v1.worker.gpu.attn_utils import (
     get_kv_cache_spec,
     init_attn_backend,
     init_kv_cache,
-    prepare_dcp_local_seq_lens,
 )
 from vllm.v1.worker.gpu.block_table import BlockTables
 from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
+from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
 from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
 from vllm.v1.worker.gpu.dp_utils import (
     get_cudagraph_and_dp_padding,
@@ -192,6 +192,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.is_first_pp_rank = True
         self.is_last_pp_rank = True

+        # Decode context parallelism.
+        self.dcp_size = self.parallel_config.decode_context_parallel_size
+        self.use_dcp = self.dcp_size > 1
+        self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
+        self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size
+
     def update_max_model_len(self, max_model_len: int) -> None:
         self.max_model_len = max_model_len
         self.req_states.max_model_len = max_model_len
@@ -251,9 +257,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             max_num_batched_tokens=self.max_num_tokens,
             max_model_len=self.max_model_len,
             device=self.device,
-            cp_kv_cache_interleave_size=(
-                self.parallel_config.cp_kv_cache_interleave_size
-            ),
+            cp_size=self.dcp_size,
+            cp_rank=self.dcp_rank,
+            cp_interleave=self.cp_interleave,
         )
         self.attn_backends, self.attn_metadata_builders = init_attn_backend(
@@ -636,18 +642,17 @@
         )
         seq_lens = self.input_buffers.seq_lens[:num_reqs]
-        dcp_size = self.parallel_config.decode_context_parallel_size
-        if dcp_size > 1:
+        dcp_local_seq_lens = None
+        if self.use_dcp:
             # Prepare dcp local seq_lens.
             prepare_dcp_local_seq_lens(
                 self.input_buffers.dcp_local_seq_lens,
-                seq_lens,
+                self.input_buffers.seq_lens,
                 num_reqs,
-                dcp_size=dcp_size,
-                dcp_rank=get_dcp_group().rank_in_group,
-                cp_kv_cache_interleave_size=(
-                    self.parallel_config.cp_kv_cache_interleave_size
-                ),
+                self.dcp_size,
+                self.dcp_rank,
+                self.cp_interleave,
             )
+            dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]
         # Prepare M-RoPE positions.
         if self.uses_mrope:
@@ -696,7 +701,7 @@
             block_tables=block_tables,
             slot_mappings=slot_mappings,
             kv_cache_config=self.kv_cache_config,
-            dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens,
+            dcp_local_seq_lens=dcp_local_seq_lens,
         )
         input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]