[Model Runner V2] Add KV Connector support (#32742)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -133,10 +133,11 @@ def init_kv_cache(
|
||||
kv_cache_config: KVCacheConfig,
|
||||
attn_backends: dict[str, AttentionBackend],
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
) -> dict[str, torch.Tensor]:
|
||||
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
|
||||
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends)
|
||||
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
|
||||
return kv_caches
|
||||
|
||||
|
||||
def build_attn_metadata(
|
||||
|
||||
125
vllm/v1/worker/gpu/kv_connector.py
Normal file
125
vllm/v1/worker/gpu/kv_connector.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer import (
|
||||
get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
kv_transfer_state,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
|
||||
from vllm.forward_context import (
|
||||
get_forward_context,
|
||||
is_forward_context_available,
|
||||
set_forward_context,
|
||||
)
|
||||
from vllm.v1.outputs import (
|
||||
EMPTY_MODEL_RUNNER_OUTPUT,
|
||||
KVConnectorOutput,
|
||||
ModelRunnerOutput,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
|
||||
class KVConnector:
    """KVConnector interface used by GPUModelRunner.

    Every hook defined on this base class is a no-op, so a plain instance
    can be used when no KV transfer is configured. Subclasses override the
    hooks to drive a real connector.
    """

    def pre_forward(self, scheduler_output: "SchedulerOutput") -> None:
        """Hook called before the model forward pass; does nothing here."""
        return None

    def post_forward(
        self, scheduler_output: "SchedulerOutput", wait_for_save: bool = True
    ) -> KVConnectorOutput | None:
        """Hook called after the model forward pass.

        The base implementation has nothing to report and returns ``None``.
        """
        return None

    def no_forward(self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
        """Hook called for steps in which the model is not run.

        Returns the shared ``EMPTY_MODEL_RUNNER_OUTPUT`` object.
        """
        return EMPTY_MODEL_RUNNER_OUTPUT

    def set_disabled(self, disabled: bool) -> None:
        """Enable or disable the connector; the base has nothing to toggle."""
        return None
|
||||
|
||||
|
||||
class ActiveKVConnector(KVConnector):
    """KVConnector implementation backed by the global KV transfer group.

    Each hook forwards to the connector object returned by
    ``get_kv_transfer_group()``. While the instance is disabled (see
    ``set_disabled``) every hook short-circuits without touching it.
    """

    def __init__(
        self, vllm_config: VllmConfig, kv_caches_dict: dict[str, torch.Tensor]
    ):
        """Look up the transfer group and hand it the KV cache tensors.

        Args:
            vllm_config: Config later passed to ``set_forward_context`` when
                no forward context is active.
            kv_caches_dict: Mapping passed to the connector's
                ``register_kv_caches``.
        """
        connector = get_kv_transfer_group()
        self.vllm_config = vllm_config
        self.kv_connector = connector
        # Register kv caches with KV Connector if applicable.
        # TODO: support cross_layers_kv_cache
        # (see https://github.com/vllm-project/vllm/pull/27743)
        connector.register_kv_caches(kv_caches_dict)
        connector.set_host_xfer_buffer_ops(copy_kv_blocks)

        self._disabled = False

    def pre_forward(self, scheduler_output: "SchedulerOutput") -> None:
        """Prepare the connector for a step and start loading KV.

        Reports preempted requests (if any), binds the step's connector
        metadata, then calls ``start_load_kv`` with a forward context.
        """
        if self._disabled:
            return

        connector = self.kv_connector

        preempted = scheduler_output.preempted_req_ids
        if preempted:
            connector.handle_preemptions(preempted)

        metadata = scheduler_output.kv_connector_metadata
        assert metadata is not None
        connector.bind_connector_metadata(metadata)

        # TODO: sort out KV Connectors' use of forward_context
        if not is_forward_context_available():
            # No context is active, so install one just for this call.
            with set_forward_context(None, self.vllm_config):
                connector.start_load_kv(get_forward_context())
            return
        connector.start_load_kv(get_forward_context())

    def post_forward(
        self, scheduler_output: "SchedulerOutput", wait_for_save: bool = True
    ) -> KVConnectorOutput | None:
        """Collect the connector's per-step results and clear its metadata.

        Args:
            scheduler_output: Supplies ``finished_req_ids`` for
                ``get_finished``.
            wait_for_save: When true, call the connector's
                ``wait_for_save`` before gathering results.

        Returns:
            A populated ``KVConnectorOutput``, or ``None`` while disabled.
        """
        if self._disabled:
            return None

        connector = self.kv_connector
        result = KVConnectorOutput()

        if wait_for_save:
            connector.wait_for_save()

        sending, recving = connector.get_finished(scheduler_output.finished_req_ids)
        result.finished_sending = sending
        result.finished_recving = recving
        result.invalid_block_ids = connector.get_block_ids_with_load_errors()
        result.kv_connector_stats = connector.get_kv_connector_stats()
        result.kv_cache_events = connector.get_kv_connector_kv_cache_events()

        # Metadata was bound in pre_forward; release it now the step is done.
        connector.clear_connector_metadata()
        return result

    def no_forward(self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
        """Run the connector hooks for a step that skips the model.

        Calls ``pre_forward`` then ``post_forward`` (without waiting for
        saves). Returns the shared ``EMPTY_MODEL_RUNNER_OUTPUT`` when there
        is nothing to report; otherwise returns a shallow copy of it that
        carries the connector output.
        """
        if self._disabled:
            return EMPTY_MODEL_RUNNER_OUTPUT

        self.pre_forward(scheduler_output)
        connector_output = self.post_forward(scheduler_output, wait_for_save=False)

        if connector_output is not None and not connector_output.is_empty():
            # Copy so the shared empty-output object is never mutated.
            runner_output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
            runner_output.kv_connector_output = connector_output
            return runner_output
        return EMPTY_MODEL_RUNNER_OUTPUT

    def set_disabled(self, disabled: bool) -> None:
        """Turn the connector off or back on.

        Besides flipping the local flag, this swaps the module-level
        ``kv_transfer_state._KV_CONNECTOR_AGENT`` between ``None`` and this
        instance's connector.
        """
        # Ensure that layer-wise connector hooks aren't called when disabled.
        agent = self.kv_connector if not disabled else None
        kv_transfer_state._KV_CONNECTOR_AGENT = agent
        self._disabled = disabled
|
||||
|
||||
|
||||
# Shared instance of the no-op base connector; get_kv_connector() returns it
# when has_kv_transfer_group() is false.
NO_OP_KV_CONNECTOR = KVConnector()
|
||||
|
||||
|
||||
def get_kv_connector(
    vllm_config: VllmConfig, kv_caches_dict: dict[str, torch.Tensor]
) -> KVConnector:
    """Return the KVConnector to use for the given KV caches.

    Args:
        vllm_config: Forwarded to ``ActiveKVConnector``.
        kv_caches_dict: Forwarded to ``ActiveKVConnector``.

    Returns:
        A new ``ActiveKVConnector`` when ``has_kv_transfer_group()`` is true,
        otherwise the shared ``NO_OP_KV_CONNECTOR``.
    """
    if has_kv_transfer_group():
        return ActiveKVConnector(vllm_config, kv_caches_dict)

    # No-op connector.
    return NO_OP_KV_CONNECTOR
|
||||
@@ -20,10 +20,7 @@ from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.outputs import (
|
||||
EMPTY_MODEL_RUNNER_OUTPUT,
|
||||
ModelRunnerOutput,
|
||||
)
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.worker.gpu.async_utils import AsyncOutput
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_attn_metadata,
|
||||
@@ -48,6 +45,11 @@ from vllm.v1.worker.gpu.input_batch import (
|
||||
prepare_pos_seq_lens,
|
||||
prepare_prefill_inputs,
|
||||
)
|
||||
from vllm.v1.worker.gpu.kv_connector import (
|
||||
NO_OP_KV_CONNECTOR,
|
||||
KVConnector,
|
||||
get_kv_connector,
|
||||
)
|
||||
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
|
||||
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
|
||||
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
||||
@@ -57,13 +59,12 @@ from vllm.v1.worker.gpu.spec_decode import init_speculator
|
||||
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
|
||||
from vllm.v1.worker.gpu.states import RequestState
|
||||
from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
@@ -172,6 +173,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.tmp_cu_num_logits = UvaBufferPool(self.max_num_reqs + 1, torch.int32)
|
||||
self.tmp_query_start_loc = UvaBufferPool(self.max_num_reqs + 1, torch.int32)
|
||||
|
||||
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
|
||||
|
||||
def update_max_model_len(self, max_model_len: int) -> None:
|
||||
self.max_model_len = max_model_len
|
||||
self.req_states.max_model_len = max_model_len
|
||||
@@ -248,13 +251,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
)
|
||||
|
||||
self.kv_caches: list[torch.Tensor] = []
|
||||
init_kv_cache(
|
||||
kv_caches_dict = init_kv_cache(
|
||||
self.kv_caches,
|
||||
self.compilation_config.static_forward_context,
|
||||
self.kv_cache_config,
|
||||
self.attn_backends,
|
||||
self.device,
|
||||
)
|
||||
self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)
|
||||
|
||||
# Attention groups are not supported.
|
||||
self.attn_groups = [] # type: ignore
|
||||
|
||||
@@ -291,18 +296,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_tokens_per_request[-1] += num_tokens % num_reqs
|
||||
assert sum(num_tokens_per_request) == num_tokens
|
||||
num_scheduled_tokens = {
|
||||
f"_dummy_req_{i}": num_tokens_per_request[i] for i in range(num_reqs)
|
||||
f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request)
|
||||
}
|
||||
dummy_scheduler_output = SchedulerOutput.make_empty()
|
||||
dummy_scheduler_output.total_num_scheduled_tokens = num_tokens
|
||||
dummy_scheduler_output.num_scheduled_tokens = num_scheduled_tokens
|
||||
|
||||
# Disable any use of KVConnector for dummy runs.
|
||||
self.kv_connector.set_disabled(True)
|
||||
|
||||
# Execute the model.
|
||||
self.execute_model(
|
||||
dummy_scheduler_output, dummy_run=True, skip_attn_for_dummy_run=skip_attn
|
||||
)
|
||||
self.kv_connector.set_disabled(False)
|
||||
assert self.execute_model_state is not None
|
||||
hidden_states, input_batch = self.execute_model_state
|
||||
hidden_states, input_batch, _ = self.execute_model_state
|
||||
sample_hidden_states = hidden_states[input_batch.logits_indices]
|
||||
return hidden_states, sample_hidden_states
|
||||
|
||||
@@ -792,7 +801,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.block_tables.apply_staged_writes()
|
||||
if scheduler_output.total_num_scheduled_tokens == 0:
|
||||
# No need to run the model.
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
empty_output = self.kv_connector.no_forward(scheduler_output)
|
||||
return empty_output
|
||||
|
||||
# Get the CUDA graph size. None means no CUDA graph is used.
|
||||
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(
|
||||
@@ -809,7 +819,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
)
|
||||
if num_tokens_after_padding == 0:
|
||||
# All DP ranks have zero tokens to run.
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
empty_output = self.kv_connector.no_forward(scheduler_output)
|
||||
return empty_output
|
||||
|
||||
if not dummy_run:
|
||||
# Common case.
|
||||
@@ -860,6 +871,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# Run CUDA graph.
|
||||
# NOTE(woosuk): Here, we don't need to pass the input tensors,
|
||||
# because they are already copied to the CUDA graph input buffers.
|
||||
self.kv_connector.pre_forward(scheduler_output)
|
||||
hidden_states = self.cudagraph_manager.run(
|
||||
input_batch.num_tokens_after_padding
|
||||
)
|
||||
@@ -877,13 +889,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
):
|
||||
self.kv_connector.pre_forward(scheduler_output)
|
||||
hidden_states = self.model(
|
||||
input_ids=input_batch.input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=input_batch.inputs_embeds,
|
||||
)
|
||||
|
||||
self.execute_model_state = hidden_states, input_batch
|
||||
kv_connector_output = self.kv_connector.post_forward(scheduler_output)
|
||||
self.execute_model_state = hidden_states, input_batch, kv_connector_output
|
||||
return None
|
||||
|
||||
@torch.inference_mode()
|
||||
@@ -892,7 +906,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
grammar_output: GrammarOutput | None,
|
||||
) -> AsyncOutput | ModelRunnerOutput:
|
||||
assert self.execute_model_state is not None
|
||||
hidden_states, input_batch = self.execute_model_state
|
||||
hidden_states, input_batch, kv_connector_output = self.execute_model_state
|
||||
self.execute_model_state = None # type: ignore
|
||||
|
||||
sampler_output, num_sampled, num_rejected = self.sample(
|
||||
@@ -917,6 +931,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
|
||||
sampled_token_ids=None, # type: ignore
|
||||
prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore[arg-type]
|
||||
kv_connector_output=kv_connector_output,
|
||||
)
|
||||
async_output = AsyncOutput(
|
||||
model_runner_output=model_runner_output,
|
||||
|
||||
@@ -24,6 +24,7 @@ from vllm.distributed import (
|
||||
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
|
||||
from vllm.distributed.kv_transfer import (
|
||||
ensure_kv_transfer_initialized,
|
||||
ensure_kv_transfer_shutdown,
|
||||
get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
)
|
||||
@@ -921,8 +922,9 @@ class Worker(WorkerBase):
|
||||
)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if runner := getattr(self, "model_runner", None):
|
||||
runner.ensure_kv_transfer_shutdown()
|
||||
# ensure_kv_transfer_shutdown can be None during interpreter shutdown.
|
||||
if ensure_kv_transfer_shutdown is not None:
|
||||
ensure_kv_transfer_shutdown()
|
||||
if self.profiler is not None:
|
||||
self.profiler.shutdown()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user