diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py
index f1b95df00..d568ccf1c 100644
--- a/vllm/v1/worker/gpu/attn_utils.py
+++ b/vllm/v1/worker/gpu/attn_utils.py
@@ -133,10 +133,11 @@ def init_kv_cache(
     kv_cache_config: KVCacheConfig,
     attn_backends: dict[str, AttentionBackend],
     device: torch.device,
-) -> None:
+) -> dict[str, torch.Tensor]:
     kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
     kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends)
     bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
+    return kv_caches
 
 
 def build_attn_metadata(
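The attn_utils.py change is small but load-bearing: `init_kv_cache` now returns the per-layer KV cache dict instead of `None`, so the caller can hand the very same tensors to a KV connector. A minimal standalone sketch of the pattern (illustrative shapes and layer names, not vLLM's real allocation logic):

```python
import torch

def init_kv_cache(layer_names: list[str]) -> dict[str, torch.Tensor]:
    # Allocate one placeholder cache per layer (shapes are made up here).
    kv_caches = {name: torch.zeros(2, 16, 8) for name in layer_names}
    # ... in vLLM, the caches are also bound into the forward context ...
    return kv_caches  # returning the dict lets the caller reuse the buffers

kv_caches_dict = init_kv_cache(["layers.0.attn", "layers.1.attn"])
# The runner can now register the exact same buffers with a connector, e.g.:
# connector.register_kv_caches(kv_caches_dict)
```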
diff --git a/vllm/v1/worker/gpu/kv_connector.py b/vllm/v1/worker/gpu/kv_connector.py
new file mode 100644
index 000000000..940068851
--- /dev/null
+++ b/vllm/v1/worker/gpu/kv_connector.py
@@ -0,0 +1,125 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import copy
+from typing import TYPE_CHECKING
+
+import torch
+
+from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer import (
+    get_kv_transfer_group,
+    has_kv_transfer_group,
+    kv_transfer_state,
+)
+from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
+from vllm.forward_context import (
+    get_forward_context,
+    is_forward_context_available,
+    set_forward_context,
+)
+from vllm.v1.outputs import (
+    EMPTY_MODEL_RUNNER_OUTPUT,
+    KVConnectorOutput,
+    ModelRunnerOutput,
+)
+
+if TYPE_CHECKING:
+    from vllm.v1.core.sched.output import SchedulerOutput
+
+
+class KVConnector:
+    """KVConnector interface used by GPUModelRunner."""
+
+    def pre_forward(self, scheduler_output: "SchedulerOutput") -> None:
+        pass
+
+    def post_forward(
+        self, scheduler_output: "SchedulerOutput", wait_for_save: bool = True
+    ) -> KVConnectorOutput | None:
+        return None
+
+    def no_forward(self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
+        return EMPTY_MODEL_RUNNER_OUTPUT
+
+    def set_disabled(self, disabled: bool) -> None:
+        pass
+
+
+class ActiveKVConnector(KVConnector):
+    def __init__(
+        self, vllm_config: VllmConfig, kv_caches_dict: dict[str, torch.Tensor]
+    ):
+        self.vllm_config = vllm_config
+        self.kv_connector = get_kv_transfer_group()
+        # Register kv caches with KV Connector if applicable.
+        # TODO: support cross_layers_kv_cache
+        # (see https://github.com/vllm-project/vllm/pull/27743)
+        self.kv_connector.register_kv_caches(kv_caches_dict)
+        self.kv_connector.set_host_xfer_buffer_ops(copy_kv_blocks)
+
+        self._disabled = False
+
+    def pre_forward(self, scheduler_output: "SchedulerOutput") -> None:
+        if self._disabled:
+            return
+
+        if scheduler_output.preempted_req_ids:
+            self.kv_connector.handle_preemptions(scheduler_output.preempted_req_ids)
+        assert scheduler_output.kv_connector_metadata is not None
+        self.kv_connector.bind_connector_metadata(
+            scheduler_output.kv_connector_metadata
+        )
+        # TODO: sort out KV Connectors' use of forward_context
+        if is_forward_context_available():
+            self.kv_connector.start_load_kv(get_forward_context())
+        else:
+            with set_forward_context(None, self.vllm_config):
+                self.kv_connector.start_load_kv(get_forward_context())
+
+    def post_forward(
+        self, scheduler_output: "SchedulerOutput", wait_for_save: bool = True
+    ) -> KVConnectorOutput | None:
+        if self._disabled:
+            return None
+
+        output = KVConnectorOutput()
+        if wait_for_save:
+            self.kv_connector.wait_for_save()
+        output.finished_sending, output.finished_recving = (
+            self.kv_connector.get_finished(scheduler_output.finished_req_ids)
+        )
+        output.invalid_block_ids = self.kv_connector.get_block_ids_with_load_errors()
+        output.kv_connector_stats = self.kv_connector.get_kv_connector_stats()
+        output.kv_cache_events = self.kv_connector.get_kv_connector_kv_cache_events()
+        self.kv_connector.clear_connector_metadata()
+        return output
+
+    def no_forward(self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
+        if self._disabled:
+            return EMPTY_MODEL_RUNNER_OUTPUT
+
+        self.pre_forward(scheduler_output)
+        kv_connector_output = self.post_forward(scheduler_output, wait_for_save=False)
+        if kv_connector_output is None or kv_connector_output.is_empty():
+            return EMPTY_MODEL_RUNNER_OUTPUT
+        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
+        output.kv_connector_output = kv_connector_output
+        return output
+
+    def set_disabled(self, disabled: bool) -> None:
+        # Ensure that layer-wise connector hooks aren't called when disabled.
+        kv_transfer_state._KV_CONNECTOR_AGENT = None if disabled else self.kv_connector
+        self._disabled = disabled
+
+
+NO_OP_KV_CONNECTOR = KVConnector()
+
+
+def get_kv_connector(
+    vllm_config: VllmConfig, kv_caches_dict: dict[str, torch.Tensor]
+) -> KVConnector:
+    if not has_kv_transfer_group():
+        # No-op connector.
+        return NO_OP_KV_CONNECTOR
+
+    return ActiveKVConnector(vllm_config, kv_caches_dict)
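The new module is built around a null-object pattern: the base `KVConnector` is an inert stand-in, so the runner can call the hooks unconditionally, and only `ActiveKVConnector` (created when a KV transfer group exists) does real work. A self-contained sketch of the idea (simplified names, not the vLLM classes):

```python
class Connector:
    """Inert base: every hook is a no-op, so callers never branch."""

    def pre_forward(self) -> None:
        pass

    def post_forward(self) -> dict | None:
        return None


class ActiveConnector(Connector):
    def pre_forward(self) -> None:
        print("kick off async KV load")

    def post_forward(self) -> dict | None:
        return {"finished_sending": set(), "finished_recving": set()}


NO_OP_CONNECTOR = Connector()


def run_step(connector: Connector) -> None:
    connector.pre_forward()      # no-op unless a transfer group exists
    # ... model forward pass would run here ...
    out = connector.post_forward()
    if out is not None:
        print("attach to runner output:", out)


run_step(NO_OP_CONNECTOR)        # hot path stays branch-free
run_step(ActiveConnector())
```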
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index ae62fdc4d..8b100da28 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -20,10 +20,7 @@ from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
-from vllm.v1.outputs import (
-    EMPTY_MODEL_RUNNER_OUTPUT,
-    ModelRunnerOutput,
-)
+from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.gpu.async_utils import AsyncOutput
 from vllm.v1.worker.gpu.attn_utils import (
     build_attn_metadata,
@@ -48,6 +45,11 @@ from vllm.v1.worker.gpu.input_batch import (
     prepare_pos_seq_lens,
     prepare_prefill_inputs,
 )
+from vllm.v1.worker.gpu.kv_connector import (
+    NO_OP_KV_CONNECTOR,
+    KVConnector,
+    get_kv_connector,
+)
 from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
 from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
 from vllm.v1.worker.gpu.sample.output import SamplerOutput
@@ -57,13 +59,12 @@ from vllm.v1.worker.gpu.spec_decode import init_speculator
 from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
 from vllm.v1.worker.gpu.states import RequestState
 from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker
-from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
 logger = init_logger(__name__)
 
 
-class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
+class GPUModelRunner(LoRAModelRunnerMixin):
     def __init__(
         self,
         vllm_config: VllmConfig,
@@ -172,6 +173,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.tmp_cu_num_logits = UvaBufferPool(self.max_num_reqs + 1, torch.int32)
         self.tmp_query_start_loc = UvaBufferPool(self.max_num_reqs + 1, torch.int32)
 
+        self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
+
     def update_max_model_len(self, max_model_len: int) -> None:
         self.max_model_len = max_model_len
         self.req_states.max_model_len = max_model_len
@@ -248,13 +251,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         )
 
         self.kv_caches: list[torch.Tensor] = []
-        init_kv_cache(
+        kv_caches_dict = init_kv_cache(
            self.kv_caches,
            self.compilation_config.static_forward_context,
            self.kv_cache_config,
            self.attn_backends,
            self.device,
         )
+        self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)
+
         # Attention groups are not supported.
         self.attn_groups = []  # type: ignore
 
@@ -291,18 +296,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             num_tokens_per_request[-1] += num_tokens % num_reqs
         assert sum(num_tokens_per_request) == num_tokens
         num_scheduled_tokens = {
-            f"_dummy_req_{i}": num_tokens_per_request[i] for i in range(num_reqs)
+            f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request)
         }
         dummy_scheduler_output = SchedulerOutput.make_empty()
         dummy_scheduler_output.total_num_scheduled_tokens = num_tokens
         dummy_scheduler_output.num_scheduled_tokens = num_scheduled_tokens
 
+        # Disable any use of KVConnector for dummy runs.
+        self.kv_connector.set_disabled(True)
+
         # Execute the model.
         self.execute_model(
             dummy_scheduler_output, dummy_run=True, skip_attn_for_dummy_run=skip_attn
         )
+        self.kv_connector.set_disabled(False)
 
         assert self.execute_model_state is not None
-        hidden_states, input_batch = self.execute_model_state
+        hidden_states, input_batch, _ = self.execute_model_state
         sample_hidden_states = hidden_states[input_batch.logits_indices]
         return hidden_states, sample_hidden_states
@@ -792,7 +801,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.block_tables.apply_staged_writes()
         if scheduler_output.total_num_scheduled_tokens == 0:
             # No need to run the model.
-            return EMPTY_MODEL_RUNNER_OUTPUT
+            empty_output = self.kv_connector.no_forward(scheduler_output)
+            return empty_output
 
         # Get the CUDA graph size. None means no CUDA graph is used.
         cudagraph_size = self.cudagraph_manager.get_cudagraph_size(
@@ -809,7 +819,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         )
         if num_tokens_after_padding == 0:
             # All DP ranks have zero tokens to run.
-            return EMPTY_MODEL_RUNNER_OUTPUT
+            empty_output = self.kv_connector.no_forward(scheduler_output)
+            return empty_output
 
         if not dummy_run:
             # Common case.
@@ -860,6 +871,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # Run CUDA graph.
             # NOTE(woosuk): Here, we don't need to pass the input tensors,
             # because they are already copied to the CUDA graph input buffers.
+            self.kv_connector.pre_forward(scheduler_output)
             hidden_states = self.cudagraph_manager.run(
                 input_batch.num_tokens_after_padding
             )
@@ -877,13 +889,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 cudagraph_runtime_mode=CUDAGraphMode.NONE,
                 num_tokens_across_dp=num_tokens_across_dp,
             ):
+                self.kv_connector.pre_forward(scheduler_output)
                 hidden_states = self.model(
                     input_ids=input_batch.input_ids,
                     positions=positions,
                     inputs_embeds=input_batch.inputs_embeds,
                 )
 
-        self.execute_model_state = hidden_states, input_batch
+        kv_connector_output = self.kv_connector.post_forward(scheduler_output)
+        self.execute_model_state = hidden_states, input_batch, kv_connector_output
         return None
 
     @torch.inference_mode()
@@ -892,7 +906,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         grammar_output: GrammarOutput | None,
     ) -> AsyncOutput | ModelRunnerOutput:
         assert self.execute_model_state is not None
-        hidden_states, input_batch = self.execute_model_state
+        hidden_states, input_batch, kv_connector_output = self.execute_model_state
         self.execute_model_state = None  # type: ignore
 
         sampler_output, num_sampled, num_rejected = self.sample(
@@ -917,6 +931,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
             sampled_token_ids=None,  # type: ignore
             prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore[arg-type]
+            kv_connector_output=kv_connector_output,
         )
         async_output = AsyncOutput(
             model_runner_output=model_runner_output,
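model_runner.py threads the connector output through the existing two-phase flow: `execute_model()` stashes it in `execute_model_state`, and `sample_tokens()` unpacks it later when building the final output. A toy sketch of that hand-off (hypothetical stand-in types, not the runner's real tensors):

```python
from typing import NamedTuple


class StepState(NamedTuple):
    hidden_states: list[float]       # stand-in for the real GPU tensor
    req_ids: list[str]
    kv_connector_output: dict | None


_state: StepState | None = None


def execute_model() -> None:
    """Phase 1: run forward; connector pre/post hooks fire around it."""
    global _state
    kv_out = {"finished_sending": set()}   # what post_forward would return
    _state = StepState([0.1, 0.2], ["req-0"], kv_out)


def sample_tokens() -> dict:
    """Phase 2: consume the stashed state exactly once."""
    global _state
    assert _state is not None
    hidden, req_ids, kv_out = _state
    _state = None
    return {"req_ids": req_ids, "kv_connector_output": kv_out}


execute_model()
print(sample_tokens())
```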
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index e39221af1..db4cb45e2 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -24,6 +24,7 @@ from vllm.distributed import (
 from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
 from vllm.distributed.kv_transfer import (
     ensure_kv_transfer_initialized,
+    ensure_kv_transfer_shutdown,
     get_kv_transfer_group,
     has_kv_transfer_group,
 )
"model_runner", None): - runner.ensure_kv_transfer_shutdown() + # has_kv_transfer_group can be None during interpreter shutdown. + if ensure_kv_transfer_shutdown is not None: + ensure_kv_transfer_shutdown() if self.profiler is not None: self.profiler.shutdown()