[P/D][V1] KV Connector API V1 (#15960)

Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Signed-off-by: remi <remi@mistral.ai>
Co-authored-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Rémi Delacourt <54138269+Flechman@users.noreply.github.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Author: Yihua Cheng
Date:   2025-04-17 15:22:40 -05:00 (committed by GitHub)
Parent: 0377b8310b
Commit: 3408e47159

24 changed files with 1377 additions and 83 deletions

@@ -9,11 +9,12 @@
 import torch.distributed
 import torch.nn as nn
 import vllm.envs as envs
-from vllm.config import ParallelConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.device_allocator.cumem import CuMemAllocator
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
+from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -110,7 +111,7 @@ class Worker(WorkerBase):
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
         # Initialize the distributed environment.
-        init_worker_distributed_environment(self.parallel_config, self.rank,
+        init_worker_distributed_environment(self.vllm_config, self.rank,
                                             self.distributed_init_method,
                                             self.local_rank)
         # Set random seed.
@@ -285,12 +286,13 @@ class Worker(WorkerBase):
 def init_worker_distributed_environment(
-    parallel_config: ParallelConfig,
+    vllm_config: VllmConfig,
     rank: int,
     distributed_init_method: Optional[str] = None,
     local_rank: int = -1,
 ) -> None:
     """Initialize the distributed environment."""
+    parallel_config = vllm_config.parallel_config
     set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
     init_distributed_environment(parallel_config.world_size, rank,
@@ -299,6 +301,8 @@ def init_worker_distributed_environment(
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
+
+    ensure_kv_transfer_initialized(vllm_config)
 
 
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     # Check if the GPU supports the dtype.
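
Taken together, the four hunks (apparently in the V1 GPU worker, likely vllm/v1/worker/gpu_worker.py) move worker initialization from a ParallelConfig-shaped helper to one driven by the whole VllmConfig, with KV-transfer setup hung off the end. The self-contained Python sketch below illustrates the pattern with simplified stand-in classes; the names mirror the diff (VllmConfig, ensure_kv_transfer_initialized), but KVTransferConfig and its fields here are illustrative assumptions, not vLLM's exact API.

# Stand-in sketch of the refactoring pattern this diff adopts: thread the
# single top-level config through worker init and derive sub-configs
# inside, so a new subsystem (KV transfer) can be wired in without
# widening the helper's signature again.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ParallelConfig:
    world_size: int = 1
    disable_custom_all_reduce: bool = False


@dataclass
class KVTransferConfig:
    # Assumed stand-in for vLLM's KV-transfer settings; "kv_connector"
    # names the connector implementation to load.
    kv_connector: str = "SharedStorageConnector"


@dataclass
class VllmConfig:
    parallel_config: ParallelConfig = field(default_factory=ParallelConfig)
    kv_transfer_config: Optional[KVTransferConfig] = None


def ensure_kv_transfer_initialized(vllm_config: VllmConfig) -> None:
    # Real vLLM builds the connector here; this stub only reports it.
    if vllm_config.kv_transfer_config is not None:
        print("KV connector:", vllm_config.kv_transfer_config.kv_connector)


def init_worker_distributed_environment(vllm_config: VllmConfig,
                                        rank: int) -> None:
    # Derive the sub-config internally, exactly as the diff now does.
    parallel_config = vllm_config.parallel_config
    print(f"rank {rank} of world_size {parallel_config.world_size} up")
    ensure_kv_transfer_initialized(vllm_config)


if __name__ == "__main__":
    init_worker_distributed_environment(
        VllmConfig(kv_transfer_config=KVTransferConfig()), rank=0)

Passing the whole VllmConfig also keeps call sites stable: when a later change adds yet another subsystem to worker startup, only the helper body has to grow, not every caller's argument list.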