[Feature] Decode Context Parallel support for GPU model runner v2 (#34179)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -12,6 +12,7 @@ from vllm.v1.attention.backend import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
AttentionSpec,
|
||||
KVCacheConfig,
|
||||
@@ -143,6 +144,28 @@ def build_slot_mappings_by_layer(
|
||||
return slot_mappings_by_layer
|
||||
|
||||
|
||||
def prepare_dcp_local_seq_lens(
|
||||
dcp_local_seq_lens: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
num_reqs: int,
|
||||
dcp_size: int,
|
||||
dcp_rank: int,
|
||||
cp_kv_cache_interleave_size: int,
|
||||
) -> None:
|
||||
"""Populate the persistent DCP local seq_lens buffer (CUDA graph safe)."""
|
||||
if dcp_size <= 1:
|
||||
return
|
||||
|
||||
local_seq_lens = get_dcp_local_seq_lens(
|
||||
seq_lens[:num_reqs],
|
||||
dcp_size=dcp_size,
|
||||
dcp_rank=dcp_rank,
|
||||
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
|
||||
)
|
||||
dcp_local_seq_lens[:num_reqs].copy_(local_seq_lens, non_blocking=True)
|
||||
dcp_local_seq_lens[num_reqs:].zero_()
|
||||
|
||||
|
||||
def build_attn_metadata(
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
num_reqs: int,
|
||||
@@ -155,9 +178,13 @@ def build_attn_metadata(
|
||||
block_tables: Sequence[torch.Tensor],
|
||||
slot_mappings: torch.Tensor,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
dcp_local_seq_lens: torch.Tensor | None = None,
|
||||
) -> dict[str, Any]:
|
||||
seq_lens = seq_lens[:num_reqs]
|
||||
|
||||
if dcp_local_seq_lens is not None:
|
||||
dcp_local_seq_lens = dcp_local_seq_lens[:num_reqs]
|
||||
|
||||
attn_metadata: dict[str, Any] = {}
|
||||
kv_cache_groups = kv_cache_config.kv_cache_groups
|
||||
for i, kv_cache_spec in enumerate(kv_cache_groups):
|
||||
@@ -175,6 +202,7 @@ def build_attn_metadata(
|
||||
block_table_tensor=block_table,
|
||||
slot_mapping=slot_mapping,
|
||||
causal=True,
|
||||
dcp_local_seq_lens=dcp_local_seq_lens,
|
||||
)
|
||||
|
||||
attn_metadata_builder = attn_metadata_builders[i]
|
||||
|
||||
@@ -4,6 +4,7 @@ from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_dcp_group
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
||||
@@ -18,19 +19,36 @@ class BlockTables:
|
||||
max_num_batched_tokens: int,
|
||||
max_model_len: int,
|
||||
device: torch.device,
|
||||
cp_kv_cache_interleave_size: int = 1,
|
||||
):
|
||||
self.block_sizes = block_sizes
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.max_model_len = max_model_len
|
||||
self.device = device
|
||||
assert cp_kv_cache_interleave_size >= 1
|
||||
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
|
||||
|
||||
try:
|
||||
dcp = get_dcp_group()
|
||||
self.dcp_world_size, self.dcp_rank = dcp.world_size, dcp.rank_in_group
|
||||
except AssertionError:
|
||||
self.dcp_world_size, self.dcp_rank = 1, 0
|
||||
# TODO(wentao): PCP supprot
|
||||
self.total_cp_world_size = self.dcp_world_size
|
||||
self.total_cp_rank = self.dcp_rank
|
||||
|
||||
self.num_kv_cache_groups = len(self.block_sizes)
|
||||
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
|
||||
self.block_tables: list[StagedWriteTensor] = []
|
||||
for i in range(self.num_kv_cache_groups):
|
||||
block_size = self.block_sizes[i]
|
||||
max_num_blocks = cdiv(self.max_model_len, block_size)
|
||||
# with DCP, a request's KV is sharded across
|
||||
# ranks, so one physical block on this rank
|
||||
# corresponds to `block_size * total_cp_world_size`
|
||||
# tokens in the global (unsharded) sequence.
|
||||
virtual_block_size = block_size * self.total_cp_world_size
|
||||
max_num_blocks = cdiv(self.max_model_len, virtual_block_size)
|
||||
block_table = StagedWriteTensor(
|
||||
(self.max_num_reqs, max_num_blocks),
|
||||
dtype=torch.int32,
|
||||
@@ -131,6 +149,9 @@ class BlockTables:
|
||||
self.block_sizes_tensor,
|
||||
self.slot_mappings,
|
||||
self.slot_mappings.stride(0),
|
||||
TOTAL_CP_WORLD_SIZE=self.total_cp_world_size,
|
||||
TOTAL_CP_RANK=self.total_cp_rank,
|
||||
CP_KV_CACHE_INTERLEAVE_SIZE=self.cp_kv_cache_interleave_size,
|
||||
PAD_ID=PAD_SLOT_ID,
|
||||
TRITON_BLOCK_SIZE=1024, # type: ignore
|
||||
)
|
||||
@@ -183,6 +204,9 @@ def _compute_slot_mappings_kernel(
|
||||
block_sizes, # [num_kv_cache_groups]
|
||||
slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
|
||||
slot_mappings_stride,
|
||||
TOTAL_CP_WORLD_SIZE: tl.constexpr,
|
||||
TOTAL_CP_RANK: tl.constexpr,
|
||||
CP_KV_CACHE_INTERLEAVE_SIZE: tl.constexpr,
|
||||
PAD_ID: tl.constexpr,
|
||||
TRITON_BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
@@ -201,6 +225,7 @@ def _compute_slot_mappings_kernel(
|
||||
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
|
||||
block_table_stride = tl.load(block_table_strides + group_id)
|
||||
block_size = tl.load(block_sizes + group_id)
|
||||
virtual_block_size = block_size * TOTAL_CP_WORLD_SIZE
|
||||
|
||||
req_state_idx = tl.load(idx_mapping + batch_idx)
|
||||
start_idx = tl.load(query_start_loc + batch_idx)
|
||||
@@ -208,11 +233,26 @@ def _compute_slot_mappings_kernel(
|
||||
for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
|
||||
positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
|
||||
block_indices = positions // block_size
|
||||
block_indices = positions // virtual_block_size
|
||||
block_numbers = tl.load(
|
||||
block_table_ptr + req_state_idx * block_table_stride + block_indices
|
||||
)
|
||||
slot_ids = block_numbers * block_size + positions % block_size
|
||||
virtual_block_offsets = positions - block_indices * virtual_block_size
|
||||
|
||||
# determine whether the token is stored on this CP rank.
|
||||
is_local = (
|
||||
virtual_block_offsets // CP_KV_CACHE_INTERLEAVE_SIZE
|
||||
) % TOTAL_CP_WORLD_SIZE == TOTAL_CP_RANK
|
||||
# mapping virture block offsets to local block offsets.
|
||||
local_block_offsets = (
|
||||
virtual_block_offsets // (TOTAL_CP_WORLD_SIZE * CP_KV_CACHE_INTERLEAVE_SIZE)
|
||||
) * CP_KV_CACHE_INTERLEAVE_SIZE + (
|
||||
virtual_block_offsets % CP_KV_CACHE_INTERLEAVE_SIZE
|
||||
)
|
||||
|
||||
# physical slot index
|
||||
slot_ids = block_numbers * block_size + local_block_offsets
|
||||
slot_ids = tl.where(is_local, slot_ids, PAD_ID)
|
||||
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
|
||||
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from tqdm import tqdm
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.distributed import get_dcp_group
|
||||
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.v1.attention.backend import AttentionMetadataBuilder
|
||||
@@ -17,6 +18,7 @@ from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_attn_metadata,
|
||||
build_slot_mappings_by_layer,
|
||||
prepare_dcp_local_seq_lens,
|
||||
)
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
|
||||
@@ -257,6 +259,23 @@ def prepare_inputs_to_capture(
|
||||
input_buffers.seq_lens[:num_reqs] = num_tokens
|
||||
input_buffers.seq_lens[num_reqs:] = 0
|
||||
|
||||
try:
|
||||
dcp_group = get_dcp_group()
|
||||
dcp_world_size = dcp_group.world_size
|
||||
dcp_rank = dcp_group.rank_in_group
|
||||
except AssertionError:
|
||||
dcp_world_size = 1
|
||||
dcp_rank = 0
|
||||
if dcp_world_size > 1:
|
||||
prepare_dcp_local_seq_lens(
|
||||
input_buffers.dcp_local_seq_lens,
|
||||
input_buffers.seq_lens,
|
||||
num_reqs,
|
||||
dcp_size=dcp_world_size,
|
||||
dcp_rank=dcp_rank,
|
||||
cp_kv_cache_interleave_size=block_tables.cp_kv_cache_interleave_size,
|
||||
)
|
||||
|
||||
input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
|
||||
slot_mappings = block_tables.slot_mappings[:, :num_tokens]
|
||||
slot_mappings_by_layer = build_slot_mappings_by_layer(
|
||||
@@ -275,5 +294,6 @@ def prepare_inputs_to_capture(
|
||||
block_tables=input_block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=kv_cache_config,
|
||||
dcp_local_seq_lens=input_buffers.dcp_local_seq_lens,
|
||||
)
|
||||
return attn_metadata, slot_mappings_by_layer
|
||||
|
||||
@@ -27,6 +27,10 @@ class InputBuffers:
|
||||
max_num_reqs + 1, dtype=torch.int32, device=device
|
||||
)
|
||||
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
|
||||
# DCP: per-request local seq_lens buffer
|
||||
self.dcp_local_seq_lens = torch.zeros(
|
||||
max_num_reqs, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -11,6 +11,7 @@ import torch.nn as nn
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_dcp_group,
|
||||
get_pp_group,
|
||||
prepare_communication_buffer_for_model,
|
||||
)
|
||||
@@ -24,6 +25,7 @@ from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
|
||||
from vllm.v1.worker.gpu.async_utils import AsyncOutput
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_attn_metadata,
|
||||
@@ -31,6 +33,7 @@ from vllm.v1.worker.gpu.attn_utils import (
|
||||
get_kv_cache_spec,
|
||||
init_attn_backend,
|
||||
init_kv_cache,
|
||||
prepare_dcp_local_seq_lens,
|
||||
)
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
|
||||
@@ -248,11 +251,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
max_num_batched_tokens=self.max_num_tokens,
|
||||
max_model_len=self.max_model_len,
|
||||
device=self.device,
|
||||
cp_kv_cache_interleave_size=(
|
||||
self.parallel_config.cp_kv_cache_interleave_size
|
||||
),
|
||||
)
|
||||
|
||||
self.attn_backends, self.attn_metadata_builders = init_attn_backend(
|
||||
self.kv_cache_config, self.vllm_config, self.device
|
||||
)
|
||||
check_attention_cp_compatibility(self.vllm_config)
|
||||
if self.do_spec_decode:
|
||||
# HACK(woosuk)
|
||||
self.speculator.set_attn(
|
||||
@@ -294,6 +301,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens,
|
||||
)
|
||||
input_batch.attn_metadata = attn_metadata
|
||||
input_batch.slot_mappings = slot_mappings_by_layer
|
||||
@@ -627,6 +635,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
seq_lens = self.input_buffers.seq_lens[:num_reqs]
|
||||
|
||||
dcp_size = self.parallel_config.decode_context_parallel_size
|
||||
if dcp_size > 1:
|
||||
prepare_dcp_local_seq_lens(
|
||||
self.input_buffers.dcp_local_seq_lens,
|
||||
seq_lens,
|
||||
num_reqs,
|
||||
dcp_size=dcp_size,
|
||||
dcp_rank=get_dcp_group().rank_in_group,
|
||||
cp_kv_cache_interleave_size=(
|
||||
self.parallel_config.cp_kv_cache_interleave_size
|
||||
),
|
||||
)
|
||||
|
||||
# Prepare M-RoPE positions.
|
||||
if self.uses_mrope:
|
||||
self.mrope_states.prepare_mrope_positions(
|
||||
@@ -674,6 +695,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens,
|
||||
)
|
||||
|
||||
input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
|
||||
|
||||
Reference in New Issue
Block a user