[P/D] NIXL Integration (#17751)
Signed-off-by: ApostaC <yihua98@uchicago.edu> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com> Signed-off-by: Robert Shaw <rshaw@neuralmagic.com> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: Brent Salisbury <bsalisbu@redhat.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: ApostaC <yihua98@uchicago.edu> Co-authored-by: Robert Shaw <rshaw@neuralmagic.com> Co-authored-by: mgoin <mgoin64@gmail.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Tyler Michael Smith <tysmith@redhat.com> Co-authored-by: Brent Salisbury <bsalisbu@redhat.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import time
|
||||
import weakref
|
||||
@@ -17,8 +18,9 @@ from vllm.config import (CompilationLevel, VllmConfig,
|
||||
get_layers_from_vllm_config)
|
||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||
from vllm.distributed.parallel_state import get_pp_group, graph_capture
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
@@ -1065,15 +1067,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
scheduler_output: "SchedulerOutput",
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> Union[ModelRunnerOutput, IntermediateTensors]:
|
||||
# Update KVConnector with the KVConnector metadata forward().
|
||||
if has_kv_transfer_group():
|
||||
get_kv_transfer_group().bind_connector_metadata(
|
||||
scheduler_output.kv_connector_metadata)
|
||||
|
||||
self._update_states(scheduler_output)
|
||||
if not scheduler_output.total_num_scheduled_tokens:
|
||||
# Return empty ModelRunnerOutput if there's no work to do.
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
if not has_kv_transfer_group():
|
||||
# Return empty ModelRunnerOutput if there's no work to do.
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
return self.kv_connector_no_forward(scheduler_output)
|
||||
|
||||
# Prepare the decoder inputs.
|
||||
attn_metadata, logits_indices, spec_decode_metadata = (
|
||||
@@ -1150,17 +1151,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
with set_forward_context(attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens):
|
||||
output = self.model(
|
||||
self.maybe_setup_kv_connector(scheduler_output)
|
||||
|
||||
model_output = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
self.maybe_wait_for_kv_save()
|
||||
finished_sending, finished_recving = (
|
||||
self.get_finished_kv_transfers(scheduler_output))
|
||||
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = output
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
hidden_states = output
|
||||
hidden_states = model_output
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
# For mid-pipeline stages, return the hidden states.
|
||||
@@ -1341,8 +1348,56 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
spec_token_ids=spec_token_ids,
|
||||
logprobs=logprobs_lists,
|
||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||
finished_sending=finished_sending,
|
||||
finished_recving=finished_recving,
|
||||
)
|
||||
|
||||
def kv_connector_no_forward(
|
||||
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
|
||||
# KV send/recv even if no work to do.
|
||||
with set_forward_context(None, self.vllm_config):
|
||||
self.maybe_setup_kv_connector(scheduler_output)
|
||||
finished_sending, finished_recving = (
|
||||
self.get_finished_kv_transfers(scheduler_output))
|
||||
|
||||
if not finished_sending and not finished_recving:
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
output.finished_sending = finished_sending
|
||||
output.finished_recving = finished_recving
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
|
||||
# Update KVConnector with the KVConnector metadata forward().
|
||||
if has_kv_transfer_group():
|
||||
kv_connector = get_kv_transfer_group()
|
||||
assert isinstance(kv_connector, KVConnectorBase_V1)
|
||||
assert scheduler_output.kv_connector_metadata is not None
|
||||
kv_connector.bind_connector_metadata(
|
||||
scheduler_output.kv_connector_metadata)
|
||||
|
||||
# Background KV cache transfers happen here.
|
||||
# These transfers are designed to be async and the requests
|
||||
# involved may be disjoint from the running requests.
|
||||
# Do this here to save a collective_rpc.
|
||||
kv_connector.start_load_kv(get_forward_context())
|
||||
|
||||
@staticmethod
|
||||
def maybe_wait_for_kv_save() -> None:
|
||||
if has_kv_transfer_group():
|
||||
get_kv_transfer_group().wait_for_save()
|
||||
|
||||
@staticmethod
|
||||
def get_finished_kv_transfers(
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
||||
if has_kv_transfer_group():
|
||||
return get_kv_transfer_group().get_finished(
|
||||
scheduler_output.finished_req_ids)
|
||||
return None, None
|
||||
|
||||
def generate_draft_token_ids(
|
||||
self,
|
||||
sampled_token_ids: list[list[int]],
|
||||
@@ -1813,6 +1868,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.vllm_config.compilation_config.static_forward_context,
|
||||
self.kv_caches)
|
||||
|
||||
if has_kv_transfer_group():
|
||||
get_kv_transfer_group().register_kv_caches(kv_caches)
|
||||
|
||||
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
|
||||
weakref.proxy(self),
|
||||
kv_cache_config.kv_cache_groups[0].kv_cache_spec,
|
||||
|
||||
Reference in New Issue
Block a user