[Model Runner V2] Minor simplification for DCP (#34786)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -12,7 +12,6 @@ from vllm.v1.attention.backend import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
AttentionSpec,
|
||||
KVCacheConfig,
|
||||
@@ -144,28 +143,6 @@ def build_slot_mappings_by_layer(
|
||||
return slot_mappings_by_layer
|
||||
|
||||
|
||||
def prepare_dcp_local_seq_lens(
|
||||
dcp_local_seq_lens: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
num_reqs: int,
|
||||
dcp_size: int,
|
||||
dcp_rank: int,
|
||||
cp_kv_cache_interleave_size: int,
|
||||
) -> None:
|
||||
"""Populate the persistent DCP local seq_lens buffer (CUDA graph safe)."""
|
||||
if dcp_size <= 1:
|
||||
return
|
||||
|
||||
local_seq_lens = get_dcp_local_seq_lens(
|
||||
seq_lens[:num_reqs],
|
||||
dcp_size=dcp_size,
|
||||
dcp_rank=dcp_rank,
|
||||
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
|
||||
)
|
||||
dcp_local_seq_lens[:num_reqs].copy_(local_seq_lens, non_blocking=True)
|
||||
dcp_local_seq_lens[num_reqs:].zero_()
|
||||
|
||||
|
||||
def build_attn_metadata(
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
num_reqs: int,
|
||||
@@ -181,7 +158,6 @@ def build_attn_metadata(
|
||||
dcp_local_seq_lens: torch.Tensor | None = None,
|
||||
) -> dict[str, Any]:
|
||||
seq_lens = seq_lens[:num_reqs]
|
||||
|
||||
if dcp_local_seq_lens is not None:
|
||||
dcp_local_seq_lens = dcp_local_seq_lens[:num_reqs]
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_dcp_group
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
||||
@@ -19,36 +18,29 @@ class BlockTables:
|
||||
max_num_batched_tokens: int,
|
||||
max_model_len: int,
|
||||
device: torch.device,
|
||||
cp_kv_cache_interleave_size: int = 1,
|
||||
cp_size: int = 1,
|
||||
cp_rank: int = 0,
|
||||
cp_interleave: int = 1,
|
||||
):
|
||||
self.block_sizes = block_sizes
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.max_model_len = max_model_len
|
||||
self.device = device
|
||||
assert cp_kv_cache_interleave_size >= 1
|
||||
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
|
||||
|
||||
try:
|
||||
dcp = get_dcp_group()
|
||||
self.dcp_world_size, self.dcp_rank = dcp.world_size, dcp.rank_in_group
|
||||
except AssertionError:
|
||||
self.dcp_world_size, self.dcp_rank = 1, 0
|
||||
# TODO(wentao): PCP supprot
|
||||
self.total_cp_world_size = self.dcp_world_size
|
||||
self.total_cp_rank = self.dcp_rank
|
||||
self.cp_size = cp_size
|
||||
self.cp_rank = cp_rank
|
||||
self.cp_interleave = cp_interleave
|
||||
|
||||
self.num_kv_cache_groups = len(self.block_sizes)
|
||||
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
|
||||
self.block_tables: list[StagedWriteTensor] = []
|
||||
for i in range(self.num_kv_cache_groups):
|
||||
block_size = self.block_sizes[i]
|
||||
# with DCP, a request's KV is sharded across
|
||||
# ranks, so one physical block on this rank
|
||||
# corresponds to `block_size * total_cp_world_size`
|
||||
# tokens in the global (unsharded) sequence.
|
||||
virtual_block_size = block_size * self.total_cp_world_size
|
||||
max_num_blocks = cdiv(self.max_model_len, virtual_block_size)
|
||||
# When using DCP, each request's KV cache is sharded among different ranks.
|
||||
# As a result, one block on the current rank covers `block_size * cp_size`
|
||||
# tokens in the full, global (unsharded) sequence.
|
||||
max_num_blocks = cdiv(self.max_model_len, block_size * self.cp_size)
|
||||
block_table = StagedWriteTensor(
|
||||
(self.max_num_reqs, max_num_blocks),
|
||||
dtype=torch.int32,
|
||||
@@ -149,9 +141,9 @@ class BlockTables:
|
||||
self.block_sizes_tensor,
|
||||
self.slot_mappings,
|
||||
self.slot_mappings.stride(0),
|
||||
TOTAL_CP_WORLD_SIZE=self.total_cp_world_size,
|
||||
TOTAL_CP_RANK=self.total_cp_rank,
|
||||
CP_KV_CACHE_INTERLEAVE_SIZE=self.cp_kv_cache_interleave_size,
|
||||
self.cp_rank,
|
||||
CP_SIZE=self.cp_size,
|
||||
CP_INTERLEAVE=self.cp_interleave,
|
||||
PAD_ID=PAD_SLOT_ID,
|
||||
TRITON_BLOCK_SIZE=1024, # type: ignore
|
||||
)
|
||||
@@ -204,9 +196,9 @@ def _compute_slot_mappings_kernel(
|
||||
block_sizes, # [num_kv_cache_groups]
|
||||
slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
|
||||
slot_mappings_stride,
|
||||
TOTAL_CP_WORLD_SIZE: tl.constexpr,
|
||||
TOTAL_CP_RANK: tl.constexpr,
|
||||
CP_KV_CACHE_INTERLEAVE_SIZE: tl.constexpr,
|
||||
cp_rank,
|
||||
CP_SIZE: tl.constexpr,
|
||||
CP_INTERLEAVE: tl.constexpr,
|
||||
PAD_ID: tl.constexpr,
|
||||
TRITON_BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
@@ -225,7 +217,6 @@ def _compute_slot_mappings_kernel(
|
||||
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
|
||||
block_table_stride = tl.load(block_table_strides + group_id)
|
||||
block_size = tl.load(block_sizes + group_id)
|
||||
virtual_block_size = block_size * TOTAL_CP_WORLD_SIZE
|
||||
|
||||
req_state_idx = tl.load(idx_mapping + batch_idx)
|
||||
start_idx = tl.load(query_start_loc + batch_idx)
|
||||
@@ -233,26 +224,25 @@ def _compute_slot_mappings_kernel(
|
||||
for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
|
||||
positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
|
||||
block_indices = positions // virtual_block_size
|
||||
|
||||
block_indices = positions // (block_size * CP_SIZE)
|
||||
block_offsets = positions % (block_size * CP_SIZE)
|
||||
block_numbers = tl.load(
|
||||
block_table_ptr + req_state_idx * block_table_stride + block_indices
|
||||
)
|
||||
virtual_block_offsets = positions - block_indices * virtual_block_size
|
||||
|
||||
# determine whether the token is stored on this CP rank.
|
||||
is_local = (
|
||||
virtual_block_offsets // CP_KV_CACHE_INTERLEAVE_SIZE
|
||||
) % TOTAL_CP_WORLD_SIZE == TOTAL_CP_RANK
|
||||
# mapping virture block offsets to local block offsets.
|
||||
local_block_offsets = (
|
||||
virtual_block_offsets // (TOTAL_CP_WORLD_SIZE * CP_KV_CACHE_INTERLEAVE_SIZE)
|
||||
) * CP_KV_CACHE_INTERLEAVE_SIZE + (
|
||||
virtual_block_offsets % CP_KV_CACHE_INTERLEAVE_SIZE
|
||||
)
|
||||
if CP_SIZE == 1:
|
||||
# Common case: Context parallelism is not used.
|
||||
slot_ids = block_numbers * block_size + block_offsets
|
||||
else:
|
||||
# Context parallelism is used.
|
||||
is_local = block_offsets // CP_INTERLEAVE % CP_SIZE == cp_rank
|
||||
rounds = block_offsets // (CP_INTERLEAVE * CP_SIZE)
|
||||
remainder = block_offsets % CP_INTERLEAVE
|
||||
local_offsets = rounds * CP_INTERLEAVE + remainder
|
||||
slot_ids = block_numbers * block_size + local_offsets
|
||||
slot_ids = tl.where(is_local, slot_ids, PAD_ID)
|
||||
|
||||
# physical slot index
|
||||
slot_ids = block_numbers * block_size + local_block_offsets
|
||||
slot_ids = tl.where(is_local, slot_ids, PAD_ID)
|
||||
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
|
||||
|
||||
|
||||
|
||||
61
vllm/v1/worker/gpu/cp_utils.py
Normal file
61
vllm/v1/worker/gpu/cp_utils.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
def prepare_dcp_local_seq_lens(
|
||||
dcp_local_seq_lens: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
num_reqs: int,
|
||||
dcp_size: int,
|
||||
dcp_rank: int,
|
||||
cp_interleave: int,
|
||||
) -> None:
|
||||
"""Populate the persistent DCP local seq_lens buffer (CUDA graph safe)."""
|
||||
if dcp_size == 1:
|
||||
return
|
||||
|
||||
max_num_reqs = dcp_local_seq_lens.shape[0]
|
||||
BLOCK_SIZE = 128
|
||||
num_blocks = triton.cdiv(max_num_reqs, BLOCK_SIZE)
|
||||
_dcp_local_seq_lens_kernel[(num_blocks,)](
|
||||
dcp_local_seq_lens,
|
||||
seq_lens,
|
||||
dcp_size,
|
||||
dcp_rank,
|
||||
cp_interleave,
|
||||
num_reqs,
|
||||
max_num_reqs,
|
||||
BLOCK_SIZE,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _dcp_local_seq_lens_kernel(
|
||||
out_ptr,
|
||||
seq_lens_ptr,
|
||||
dcp_size,
|
||||
dcp_rank,
|
||||
cp_interleave,
|
||||
num_reqs,
|
||||
max_num_reqs,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
block = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
seq_lens = tl.load(seq_lens_ptr + block, mask=block < num_reqs)
|
||||
|
||||
# Distribute KV cache among different ranks, in a round-robin manner.
|
||||
rounds = seq_lens // (dcp_size * cp_interleave)
|
||||
remainder = seq_lens % (dcp_size * cp_interleave)
|
||||
|
||||
remainder = tl.maximum(remainder - dcp_rank * cp_interleave, 0)
|
||||
remainder = tl.minimum(remainder, cp_interleave)
|
||||
local_seq_lens = rounds * cp_interleave + remainder
|
||||
|
||||
# For [num_reqs, max_num_reqs), pad with 0
|
||||
local_seq_lens = tl.where(block < num_reqs, local_seq_lens, 0)
|
||||
tl.store(out_ptr + block, local_seq_lens, mask=block < max_num_reqs)
|
||||
@@ -10,7 +10,6 @@ from tqdm import tqdm
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.distributed import get_dcp_group
|
||||
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.v1.attention.backend import AttentionMetadataBuilder
|
||||
@@ -18,7 +17,6 @@ from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_attn_metadata,
|
||||
build_slot_mappings_by_layer,
|
||||
prepare_dcp_local_seq_lens,
|
||||
)
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
|
||||
@@ -259,22 +257,8 @@ def prepare_inputs_to_capture(
|
||||
input_buffers.seq_lens[:num_reqs] = num_tokens
|
||||
input_buffers.seq_lens[num_reqs:] = 0
|
||||
|
||||
try:
|
||||
dcp_group = get_dcp_group()
|
||||
dcp_world_size = dcp_group.world_size
|
||||
dcp_rank = dcp_group.rank_in_group
|
||||
except AssertionError:
|
||||
dcp_world_size = 1
|
||||
dcp_rank = 0
|
||||
if dcp_world_size > 1:
|
||||
prepare_dcp_local_seq_lens(
|
||||
input_buffers.dcp_local_seq_lens,
|
||||
input_buffers.seq_lens,
|
||||
num_reqs,
|
||||
dcp_size=dcp_world_size,
|
||||
dcp_rank=dcp_rank,
|
||||
cp_kv_cache_interleave_size=block_tables.cp_kv_cache_interleave_size,
|
||||
)
|
||||
input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
|
||||
input_buffers.dcp_local_seq_lens[num_reqs:] = 0
|
||||
|
||||
input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
|
||||
slot_mappings = block_tables.slot_mappings[:, :num_tokens]
|
||||
|
||||
@@ -33,10 +33,10 @@ from vllm.v1.worker.gpu.attn_utils import (
|
||||
get_kv_cache_spec,
|
||||
init_attn_backend,
|
||||
init_kv_cache,
|
||||
prepare_dcp_local_seq_lens,
|
||||
)
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
|
||||
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
|
||||
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
|
||||
from vllm.v1.worker.gpu.dp_utils import (
|
||||
get_cudagraph_and_dp_padding,
|
||||
@@ -192,6 +192,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.is_first_pp_rank = True
|
||||
self.is_last_pp_rank = True
|
||||
|
||||
# Decode context parallelism.
|
||||
self.dcp_size = self.parallel_config.decode_context_parallel_size
|
||||
self.use_dcp = self.dcp_size > 1
|
||||
self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
|
||||
self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size
|
||||
|
||||
def update_max_model_len(self, max_model_len: int) -> None:
|
||||
self.max_model_len = max_model_len
|
||||
self.req_states.max_model_len = max_model_len
|
||||
@@ -251,9 +257,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
max_num_batched_tokens=self.max_num_tokens,
|
||||
max_model_len=self.max_model_len,
|
||||
device=self.device,
|
||||
cp_kv_cache_interleave_size=(
|
||||
self.parallel_config.cp_kv_cache_interleave_size
|
||||
),
|
||||
cp_size=self.dcp_size,
|
||||
cp_rank=self.dcp_rank,
|
||||
cp_interleave=self.cp_interleave,
|
||||
)
|
||||
|
||||
self.attn_backends, self.attn_metadata_builders = init_attn_backend(
|
||||
@@ -636,18 +642,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
seq_lens = self.input_buffers.seq_lens[:num_reqs]
|
||||
|
||||
dcp_size = self.parallel_config.decode_context_parallel_size
|
||||
if dcp_size > 1:
|
||||
if self.use_dcp:
|
||||
# Prepare dcp local seq_lens.
|
||||
prepare_dcp_local_seq_lens(
|
||||
self.input_buffers.dcp_local_seq_lens,
|
||||
seq_lens,
|
||||
self.input_buffers.seq_lens,
|
||||
num_reqs,
|
||||
dcp_size=dcp_size,
|
||||
dcp_rank=get_dcp_group().rank_in_group,
|
||||
cp_kv_cache_interleave_size=(
|
||||
self.parallel_config.cp_kv_cache_interleave_size
|
||||
),
|
||||
self.dcp_size,
|
||||
self.dcp_rank,
|
||||
self.cp_interleave,
|
||||
)
|
||||
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]
|
||||
|
||||
# Prepare M-RoPE positions.
|
||||
if self.uses_mrope:
|
||||
@@ -696,7 +701,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens,
|
||||
dcp_local_seq_lens=dcp_local_seq_lens,
|
||||
)
|
||||
|
||||
input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
|
||||
|
||||
Reference in New Issue
Block a user